@@ -68,16 +68,16 @@ def _model_fn(features, labels, mode, config):
68
68
cross_out = CrossNet (cross_num , l2_reg = l2_reg_cross )(dnn_input )
69
69
stack_out = tf .keras .layers .Concatenate ()([cross_out , deep_out ])
70
70
final_logit = tf .keras .layers .Dense (
71
- 1 , use_bias = False , activation = None )(stack_out )
71
+ 1 , use_bias = False , kernel_initializer = tf . keras . initializers . glorot_normal ( seed ) )(stack_out )
72
72
elif len (dnn_hidden_units ) > 0 : # Only Deep
73
73
deep_out = DNN (dnn_hidden_units , dnn_activation , l2_reg_dnn , dnn_dropout ,
74
74
dnn_use_bn , seed )(dnn_input , training = train_flag )
75
75
final_logit = tf .keras .layers .Dense (
76
- 1 , use_bias = False , activation = None )(deep_out )
76
+ 1 , use_bias = False , kernel_initializer = tf . keras . initializers . glorot_normal ( seed ) )(deep_out )
77
77
elif cross_num > 0 : # Only Cross
78
78
cross_out = CrossNet (cross_num , l2_reg = l2_reg_cross )(dnn_input )
79
79
final_logit = tf .keras .layers .Dense (
80
- 1 , use_bias = False , activation = None )(cross_out )
80
+ 1 , use_bias = False , kernel_initializer = tf . keras . initializers . glorot_normal ( seed ) )(cross_out )
81
81
else : # Error
82
82
raise NotImplementedError
83
83
0 commit comments