
Commit ef3eff6

pandeconscious and Harshit Pande authored
FEFM/DeepFEFM (#364)
add FEFM and DeepFEFM

Co-authored-by: Harshit Pande <[email protected]>
1 parent c98a783 commit ef3eff6

File tree

12 files changed: +405 −6 lines

README.md

Lines changed: 1 addition & 0 deletions

```diff
@@ -58,6 +58,7 @@ Let's [**Get Started!**](https://deepctr-doc.readthedocs.io/en/latest/Quick-Star
 | IFM | [IJCAI 2019][An Input-aware Factorization Machine for Sparse Prediction](https://www.ijcai.org/Proceedings/2019/0203.pdf) |
 | DCN V2 | [arxiv 2020][DCN V2: Improved Deep & Cross Network and Practical Lessons for Web-scale Learning to Rank Systems](https://arxiv.org/abs/2008.13535) |
 | DIFM | [IJCAI 2020][A Dual Input-aware Factorization Machine for CTR Prediction](https://www.ijcai.org/Proceedings/2020/0434.pdf) |
+| FEFM and DeepFEFM | [arxiv 2020][Field-Embedded Factorization Machines for Click-through rate prediction](https://arxiv.org/abs/2009.09931) |
 
 ## Citation
```

deepctr/estimator/models/__init__.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -10,3 +10,4 @@
 from .pnn import PNNEstimator
 from .wdl import WDLEstimator
 from .xdeepfm import xDeepFMEstimator
+from .deepfefm import DeepFEFMEstimator
```
deepctr/estimator/models/deepfefm.py (new file)

Lines changed: 92 additions & 0 deletions

```python
# -*- coding:utf-8 -*-
"""
Author:
    Harshit Pande

Reference:
    [1] Field-Embedded Factorization Machines for Click-through Rate Prediction
    (https://arxiv.org/abs/2009.09931)

"""

import tensorflow as tf

from ..feature_column import get_linear_logit, input_from_feature_columns
from ..utils import DNN_SCOPE_NAME, deepctr_model_fn, variable_scope
from ...layers.core import DNN
from ...layers.interaction import FEFMLayer
from ...layers.utils import concat_func, add_func, combined_dnn_input, reduce_sum


def DeepFEFMEstimator(linear_feature_columns, dnn_feature_columns, embedding_size=48,
                      dnn_hidden_units=(1024, 1024, 1024), l2_reg_linear=0.000001,
                      l2_reg_embedding_feat=0.00001, l2_reg_embedding_field=0.0000001,
                      l2_reg_dnn=0, seed=1024, dnn_dropout=0.2, dnn_activation='relu',
                      dnn_use_bn=False, task='binary', model_dir=None, config=None,
                      linear_optimizer='Ftrl', dnn_optimizer='Adagrad', training_chief_hooks=None):
    """Instantiates the DeepFEFM network architecture or the shallow FEFM architecture.
    (Ablation support is not provided here because the estimator is meant for production;
    ablation support is provided in the DeepFEFM implementation in ``models``.)

    :param linear_feature_columns: An iterable containing all the features used by the linear part of the model.
    :param dnn_feature_columns: An iterable containing all the features used by the deep part of the model.
    :param embedding_size: positive integer, sparse feature embedding size
    :param dnn_hidden_units: list of positive integers or empty list, the layer number and units in each layer of the DNN
    :param l2_reg_linear: float. L2 regularizer strength applied to the linear part
    :param l2_reg_embedding_feat: float. L2 regularizer strength applied to the embedding vectors of features
    :param l2_reg_embedding_field: float. L2 regularizer strength applied to the field-pair embeddings
    :param l2_reg_dnn: float. L2 regularizer strength applied to the DNN
    :param seed: integer, to use as random seed.
    :param dnn_dropout: float in [0,1), the probability of dropping a given DNN coordinate.
    :param dnn_activation: activation function to use in the DNN
    :param dnn_use_bn: bool. Whether to use BatchNormalization before activation in the DNN
    :param task: str, ``"binary"`` for binary logloss or ``"regression"`` for regression loss
    :param model_dir: Directory to save model parameters, graph, etc. This can
        also be used to load checkpoints from the directory into an estimator
        to continue training a previously saved model.
    :param config: tf.RunConfig object to configure the runtime settings.
    :param linear_optimizer: An instance of `tf.Optimizer` used to apply gradients to
        the linear part of the model. Defaults to the FTRL optimizer.
    :param dnn_optimizer: An instance of `tf.Optimizer` used to apply gradients to
        the deep part of the model. Defaults to the Adagrad optimizer.
    :param training_chief_hooks: Iterable of `tf.train.SessionRunHook` objects to
        run on the chief worker during training.
    :return: A TensorFlow Estimator instance.
    """

    def _model_fn(features, labels, mode, config):
        train_flag = (mode == tf.estimator.ModeKeys.TRAIN)

        linear_logits = get_linear_logit(features, linear_feature_columns, l2_reg_linear=l2_reg_linear)
        final_logit_components = [linear_logits]

        with variable_scope(DNN_SCOPE_NAME):
            sparse_embedding_list, dense_value_list = input_from_feature_columns(
                features, dnn_feature_columns, l2_reg_embedding=l2_reg_embedding_feat)

            fefm_interaction_embedding = FEFMLayer(
                num_fields=len(sparse_embedding_list), embedding_size=embedding_size,
                regularizer=l2_reg_embedding_field)(concat_func(sparse_embedding_list, axis=1))

            fefm_logit = tf.keras.layers.Lambda(
                lambda x: reduce_sum(x, axis=1, keep_dims=True))(fefm_interaction_embedding)

            final_logit_components.append(fefm_logit)

            if dnn_hidden_units:
                dnn_input = combined_dnn_input(sparse_embedding_list, dense_value_list)
                dnn_input = concat_func([dnn_input, fefm_interaction_embedding], axis=1)

                dnn_output = DNN(dnn_hidden_units, dnn_activation, l2_reg_dnn, dnn_dropout,
                                 dnn_use_bn, seed=seed)(dnn_input, training=train_flag)

                dnn_logit = tf.keras.layers.Dense(
                    1, use_bias=False, kernel_initializer=tf.keras.initializers.glorot_normal(seed))(dnn_output)

                final_logit_components.append(dnn_logit)

        logits = add_func(final_logit_components)

        return deepctr_model_fn(features, mode, logits, labels, task, linear_optimizer,
                                dnn_optimizer, training_chief_hooks=training_chief_hooks)

    return tf.estimator.Estimator(_model_fn, model_dir=model_dir, config=config)
```
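A minimal sketch of wiring up the new estimator, assuming TF feature columns as consumed by the other deepctr estimators (feature names and bucket sizes here are illustrative; the `dimension` of each embedding column must equal `embedding_size`, since `FEFMLayer` builds its field-pair matrices at that size):

```python
import tensorflow as tf

from deepctr.estimator.models import DeepFEFMEstimator

# illustrative categorical features; deepctr estimators consume tf.feature_column objects
user_id = tf.feature_column.categorical_column_with_identity('user_id', num_buckets=100)
item_id = tf.feature_column.categorical_column_with_identity('item_id', num_buckets=200)

linear_cols = [user_id, item_id]
# embedding dimension 48 matches the default embedding_size of DeepFEFMEstimator
dnn_cols = [tf.feature_column.embedding_column(user_id, dimension=48),
            tf.feature_column.embedding_column(item_id, dimension=48)]

model = DeepFEFMEstimator(linear_cols, dnn_cols, embedding_size=48, task='binary')
# model.train(input_fn=...)  # then the usual tf.estimator train/evaluate loop
```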

deepctr/feature_column.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -181,7 +181,7 @@ def get_linear_logit(features, feature_columns, units=1, use_bias=False, seed=10
         dense_input = concat_func(dense_input_list)
         linear_logit = Linear(l2_reg, mode=1, use_bias=use_bias, seed=seed)(dense_input)
     else:  # empty feature_columns
-        return Lambda(lambda x: tf.constant([[0.0]]))(features.values()[0])
+        return Lambda(lambda x: tf.constant([[0.0]]))(list(features.values())[0])
     linear_logit_list.append(linear_logit)
 
     return concat_func(linear_logit_list)
```
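The one-line change above fixes a Python 3 incompatibility: `dict.values()` returns a view object that cannot be indexed. A minimal illustration of the failure mode (plain strings stand in for Keras tensors):

```python
features = {'user_id': 'tensor_a', 'item_id': 'tensor_b'}

# features.values()[0]               # Python 3: TypeError: 'dict_values' object is not subscriptable
first = list(features.values())[0]   # materialize the view first, then index
print(first)                         # tensor_a
```

This branch is only reached when the linear part has no feature columns, which is presumably why the bug went unnoticed until now.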

deepctr/layers/__init__.py

Lines changed: 5 additions & 2 deletions

```diff
@@ -5,12 +5,13 @@
 from .interaction import (CIN, FM, AFMLayer, BiInteractionPooling, CrossNet, CrossNetMix,
                           InnerProductLayer, InteractingLayer,
                           OutterProductLayer, FGCNNLayer, SENETLayer, BilinearInteraction,
-                          FieldWiseBiInteraction, FwFMLayer)
+                          FieldWiseBiInteraction, FwFMLayer, FEFMLayer)
 from .normalization import LayerNormalization
 from .sequence import (AttentionSequencePoolingLayer, BiasEncoding, BiLSTM,
                        KMaxPooling, SequencePoolingLayer, WeightedSequenceLayer,
                        Transformer, DynamicGRU)
-from .utils import NoMask, Hash, Linear, Add, combined_dnn_input, softmax
+
+from .utils import NoMask, Hash, Linear, Add, combined_dnn_input, softmax, reduce_sum
 
 custom_objects = {'tf': tf,
                   'InnerProductLayer': InnerProductLayer,
@@ -45,4 +46,6 @@
                   'FieldWiseBiInteraction': FieldWiseBiInteraction,
                   'FwFMLayer': FwFMLayer,
                   'softmax': softmax,
+                  'FEFMLayer': FEFMLayer,
+                  'reduce_sum': reduce_sum
                   }
```
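Registering `FEFMLayer` and `reduce_sum` in `custom_objects` is what lets a saved DeepFEFM model be deserialized. A sketch of the usual DeepCTR reload pattern (the `.h5` path is illustrative):

```python
from tensorflow.keras.models import load_model

from deepctr.layers import custom_objects

# model.save('DeepFEFM.h5')  # after training a DeepFEFM instance
model = load_model('DeepFEFM.h5', custom_objects)  # fails with an unknown-layer error without custom_objects
```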

deepctr/layers/interaction.py

Lines changed: 85 additions & 0 deletions

The new `FEFMLayer` class is appended after the final `get_config` of the preceding layer:

```python
class FEFMLayer(Layer):
    """Field-Embedded Factorization Machines

    Input shape
        - 3D tensor with shape: ``(batch_size, field_size, embedding_size)``.

    Output shape
        - 2D tensor with shape: ``(batch_size, (num_fields * (num_fields - 1)) / 2)``,
          the concatenated FEFM interaction embeddings.

    Arguments
        - **num_fields** : integer, number of fields
        - **embedding_size** : integer, embedding dimension
        - **regularizer** : L2 regularizer weight for the field-pair matrix embedding parameters of FEFM

    References
        - [Field-Embedded Factorization Machines for Click-through Rate Prediction](https://arxiv.org/pdf/2009.09931.pdf)
    """

    def __init__(self, num_fields, embedding_size, regularizer, **kwargs):
        self.num_fields = num_fields
        self.embedding_size = embedding_size
        self.regularizer = regularizer
        super(FEFMLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        if len(input_shape) != 3:
            raise ValueError("Unexpected inputs dimensions %d, expect to be 3 dimensions"
                             % (len(input_shape)))

        if input_shape[1] != self.num_fields:
            raise ValueError("Mismatch in number of fields {} and "
                             "concatenated embeddings dims {}".format(self.num_fields, input_shape[1]))

        self.field_embeddings = {}
        for fi, fj in itertools.combinations(range(self.num_fields), 2):
            field_pair_id = str(fi) + "-" + str(fj)
            # one trainable (embedding_size x embedding_size) matrix per field pair
            self.field_embeddings[field_pair_id] = self.add_weight(
                name='field_embeddings' + field_pair_id,
                shape=(self.embedding_size, self.embedding_size),
                initializer=TruncatedNormal(),
                regularizer=l2(self.regularizer),
                trainable=True)

        super(FEFMLayer, self).build(input_shape)  # Be sure to call this somewhere!

    def call(self, inputs, **kwargs):
        if K.ndim(inputs) != 3:
            raise ValueError("Unexpected inputs dimensions %d, expect to be 3 dimensions"
                             % (K.ndim(inputs)))

        if inputs.shape[1] != self.num_fields:
            raise ValueError("Mismatch in number of fields {} and "
                             "concatenated embeddings dims {}".format(self.num_fields, inputs.shape[1]))

        pairwise_inner_prods = []
        for fi, fj in itertools.combinations(range(self.num_fields), 2):
            field_pair_id = str(fi) + "-" + str(fj)
            feat_embed_i = tf.squeeze(inputs[0:, fi:fi + 1, 0:], axis=1)
            feat_embed_j = tf.squeeze(inputs[0:, fj:fj + 1, 0:], axis=1)
            field_pair_embed_ij = self.field_embeddings[field_pair_id]

            # symmetrize the field-pair matrix (W + W^T) before the bilinear product
            feat_embed_i_tr = tf.matmul(feat_embed_i, field_pair_embed_ij + tf.transpose(field_pair_embed_ij))

            f = batch_dot(feat_embed_i_tr, feat_embed_j, axes=1)
            pairwise_inner_prods.append(f)

        concat_vec = tf.concat(pairwise_inner_prods, axis=1)
        return concat_vec

    def compute_output_shape(self, input_shape):
        return (None, (self.num_fields * (self.num_fields - 1)) / 2)

    def get_config(self):
        config = super(FEFMLayer, self).get_config().copy()
        config.update({
            'num_fields': self.num_fields,
            'regularizer': self.regularizer,
            'embedding_size': self.embedding_size
        })
        return config
```
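For each field pair (fi, fj) the layer computes a symmetric bilinear interaction, feat_embed_i^T (W_ij + W_ij^T) feat_embed_j, and concatenates the resulting scalars, so the output width is the number of field pairs. A quick shape check, assuming TF 2.x eager execution (batch size, field count, and embedding size are arbitrary):

```python
import tensorflow as tf

from deepctr.layers.interaction import FEFMLayer

batch_size, num_fields, embedding_size = 4, 3, 8
inputs = tf.random.normal((batch_size, num_fields, embedding_size))

layer = FEFMLayer(num_fields=num_fields, embedding_size=embedding_size, regularizer=1e-7)
out = layer(inputs)

# 3 fields -> 3 * 2 / 2 = 3 field pairs, one interaction scalar per pair
print(out.shape)  # (4, 3)
```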

deepctr/models/__init__.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -21,6 +21,7 @@
 from .flen import FLEN
 from .fwfm import FwFM
 from .bst import BST
+from .deepfefm import DeepFEFM
 
 __all__ = ["AFM", "CCPM", "DCN", "IFM", "DIFM", "DCNMix", "MLR", "DeepFM", "MLR", "NFM", "DIN", "DIEN", "FNN", "PNN",
-           "WDL", "xDeepFM", "AutoInt", "ONN", "FGCNN", "DSIN", "FiBiNET", 'FLEN', "FwFM", "BST"]
+           "WDL", "xDeepFM", "AutoInt", "ONN", "FGCNN", "DSIN", "FiBiNET", 'FLEN', "FwFM", "BST", "DeepFEFM"]
```

deepctr/models/deepfefm.py

Lines changed: 102 additions & 0 deletions

```python
# -*- coding:utf-8 -*-
"""
Author:
    Harshit Pande

Reference:
    [1] Field-Embedded Factorization Machines for Click-through Rate Prediction
    (https://arxiv.org/pdf/2009.09931.pdf)

This file also supports all the possible ablation studies for reproducibility.
"""

from itertools import chain

import tensorflow as tf

from ..feature_column import input_from_feature_columns, get_linear_logit, build_input_features, DEFAULT_GROUP_NAME
from ..layers.core import PredictionLayer, DNN
from ..layers.interaction import FEFMLayer
from ..layers.utils import concat_func, combined_dnn_input, reduce_sum


def DeepFEFM(linear_feature_columns, dnn_feature_columns, embedding_size=48, use_fefm=True,
             dnn_hidden_units=(1024, 1024, 1024), l2_reg_linear=0.000001, l2_reg_embedding_feat=0.00001,
             l2_reg_embedding_field=0.0000001, l2_reg_dnn=0, seed=1024, dnn_dropout=0.2,
             exclude_feature_embed_in_dnn=False, use_linear=True, use_fefm_embed_in_dnn=True,
             dnn_activation='relu', dnn_use_bn=False, task='binary'):
    """Instantiates the DeepFEFM network architecture or the shallow FEFM architecture (ablation studies supported).

    :param linear_feature_columns: An iterable containing all the features used by the linear part of the model.
    :param dnn_feature_columns: An iterable containing all the features used by the deep part of the model.
    :param embedding_size: positive integer, sparse feature embedding size
    :param use_fefm: bool, use the FEFM logit or not (doesn't affect the FEFM embeddings in the DNN; controls only the use of the final FEFM logit)
    :param dnn_hidden_units: list of positive integers or empty list, the layer number and units in each layer of the DNN
    :param l2_reg_linear: float. L2 regularizer strength applied to the linear part
    :param l2_reg_embedding_feat: float. L2 regularizer strength applied to the embedding vectors of features
    :param l2_reg_embedding_field: float. L2 regularizer strength applied to the field-pair embeddings
    :param l2_reg_dnn: float. L2 regularizer strength applied to the DNN
    :param seed: integer, to use as random seed.
    :param dnn_dropout: float in [0,1), the probability of dropping a given DNN coordinate.
    :param exclude_feature_embed_in_dnn: bool, used in ablation studies for removing feature embeddings from the DNN
    :param use_linear: bool, used in ablation studies
    :param use_fefm_embed_in_dnn: bool, True if FEFM interaction embeddings are to be used in the DNN (set False only for ablation)
    :param dnn_activation: activation function to use in the DNN
    :param dnn_use_bn: bool. Whether to use BatchNormalization before activation in the DNN
    :param task: str, ``"binary"`` for binary logloss or ``"regression"`` for regression loss
    :return: A Keras model instance.
    """

    features = build_input_features(linear_feature_columns + dnn_feature_columns)

    inputs_list = list(features.values())

    linear_logit = get_linear_logit(features, linear_feature_columns, l2_reg=l2_reg_linear, seed=seed,
                                    prefix='linear')

    group_embedding_dict, dense_value_list = input_from_feature_columns(features, dnn_feature_columns,
                                                                        l2_reg_embedding_feat,
                                                                        seed, support_group=True)

    fefm_interaction_embedding = concat_func([FEFMLayer(num_fields=len(v), embedding_size=embedding_size,
                                                        regularizer=l2_reg_embedding_field)(concat_func(v, axis=1))
                                              for k, v in group_embedding_dict.items()
                                              if k in [DEFAULT_GROUP_NAME]], axis=1)

    dnn_input = combined_dnn_input(list(chain.from_iterable(group_embedding_dict.values())), dense_value_list)

    # if use_fefm_embed_in_dnn is set to False it is Ablation4 (use False only for ablation)
    if use_fefm_embed_in_dnn:
        if exclude_feature_embed_in_dnn:
            # Ablation3: remove feature vector embeddings from the DNN input
            dnn_input = fefm_interaction_embedding
        else:
            # No ablation
            dnn_input = concat_func([dnn_input, fefm_interaction_embedding], axis=1)

    dnn_out = DNN(dnn_hidden_units, dnn_activation, l2_reg_dnn, dnn_dropout, dnn_use_bn, seed=seed)(dnn_input)

    dnn_logit = tf.keras.layers.Dense(
        1, use_bias=False, kernel_initializer=tf.keras.initializers.glorot_normal(seed))(dnn_out)

    fefm_logit = tf.keras.layers.Lambda(lambda x: reduce_sum(x, axis=1, keep_dims=True))(fefm_interaction_embedding)

    if len(dnn_hidden_units) == 0 and use_fefm is False and use_linear is True:  # only linear
        final_logit = linear_logit
    elif len(dnn_hidden_units) == 0 and use_fefm is True and use_linear is True:  # linear + FEFM
        final_logit = tf.keras.layers.add([linear_logit, fefm_logit])
    elif len(dnn_hidden_units) > 0 and use_fefm is False and use_linear is True:  # linear + Deep (Ablation1)
        final_logit = tf.keras.layers.add([linear_logit, dnn_logit])
    elif len(dnn_hidden_units) > 0 and use_fefm is True and use_linear is True:  # linear + FEFM + Deep
        final_logit = tf.keras.layers.add([linear_logit, fefm_logit, dnn_logit])
    elif len(dnn_hidden_units) == 0 and use_fefm is True and use_linear is False:  # only FEFM (shallow)
        final_logit = fefm_logit
    elif len(dnn_hidden_units) > 0 and use_fefm is False and use_linear is False:  # only Deep
        final_logit = dnn_logit
    elif len(dnn_hidden_units) > 0 and use_fefm is True and use_linear is False:  # FEFM + Deep (Ablation2)
        final_logit = tf.keras.layers.add([fefm_logit, dnn_logit])
    else:
        raise NotImplementedError

    output = PredictionLayer(task)(final_logit)
    model = tf.keras.models.Model(inputs=inputs_list, outputs=output)
    return model
```
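A minimal end-to-end sketch of training the new Keras model on synthetic data, assuming the standard DeepCTR feature-column API (`SparseFeat`, `get_feature_names`); note that each feature's `embedding_dim` must equal the model's `embedding_size`, since `FEFMLayer` builds its field-pair matrices at that size:

```python
import numpy as np

from deepctr.feature_column import SparseFeat, get_feature_names
from deepctr.models import DeepFEFM

# three synthetic categorical fields; embedding_dim matches embedding_size below
feature_columns = [SparseFeat(name, vocabulary_size=100, embedding_dim=48)
                   for name in ['user_id', 'item_id', 'category']]
feature_names = get_feature_names(feature_columns)

# random training data: 256 samples with binary labels
model_input = {name: np.random.randint(0, 100, size=256) for name in feature_names}
labels = np.random.randint(0, 2, size=256)

model = DeepFEFM(feature_columns, feature_columns, embedding_size=48,
                 dnn_hidden_units=(64, 64), task='binary')
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['AUC'])
model.fit(model_input, labels, batch_size=64, epochs=1)
```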
