Skip to content

Commit 62d2afc

Browse files
authored
Use seed in last dense layer of DeepFM (#255)
* Use seed in last dense layer of DeepFM
1 parent f88c8ab commit 62d2afc

File tree

2 files changed

+6
-2
lines changed

2 files changed

+6
-2
lines changed

deepctr/estimator/models/deepfm.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010

1111
import tensorflow as tf
1212

13+
from tensorflow.python.keras.initializers import glorot_normal
14+
1315
from ..feature_column import get_linear_logit, input_from_feature_columns
1416
from ..utils import deepctr_model_fn, DNN_SCOPE_NAME, variable_scope
1517
from ...layers.core import DNN
@@ -66,7 +68,7 @@ def _model_fn(features, labels, mode, config):
6668
dnn_output = DNN(dnn_hidden_units, dnn_activation, l2_reg_dnn, dnn_dropout,
6769
dnn_use_bn, seed)(dnn_input, training=train_flag)
6870
dnn_logit = tf.keras.layers.Dense(
69-
1, use_bias=False, activation=None)(dnn_output)
71+
1, use_bias=False, activation=None, kernel_initializer=glorot_normal(seed=seed))(dnn_output)
7072

7173
logits = linear_logits + fm_logit + dnn_logit
7274

deepctr/models/deepfm.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
import tensorflow as tf
1414

15+
from tensorflow.python.keras.initializers import glorot_normal
16+
1517
from ..feature_column import build_input_features, get_linear_logit, DEFAULT_GROUP_NAME, input_from_feature_columns
1618
from ..layers.core import PredictionLayer, DNN
1719
from ..layers.interaction import FM
@@ -57,7 +59,7 @@ def DeepFM(linear_feature_columns, dnn_feature_columns, fm_group=[DEFAULT_GROUP_
5759
dnn_output = DNN(dnn_hidden_units, dnn_activation, l2_reg_dnn, dnn_dropout,
5860
dnn_use_bn, seed)(dnn_input)
5961
dnn_logit = tf.keras.layers.Dense(
60-
1, use_bias=False, activation=None)(dnn_output)
62+
1, use_bias=False, activation=None, kernel_initializer=glorot_normal(seed=seed))(dnn_output)
6163

6264
final_logit = add_func([linear_logit, fm_logit, dnn_logit])
6365

0 commit comments

Comments
 (0)