Skip to content

Commit c13aba6

Browse files
Heyi007何意shenweichen
authored
Add EDCN model.
* feat: Add EDCN model. Co-authored-by: 何意 <[email protected]> Co-authored-by: 浅梦 <[email protected]>
1 parent ec78b9b commit c13aba6

File tree

7 files changed

+278
-7
lines changed

7 files changed

+278
-7
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ Introduction](https://zhuanlan.zhihu.com/p/53231955)) and [welcome to join us!](
6666
| ESMM | [SIGIR 2018][Entire Space Multi-Task Model: An Effective Approach for Estimating Post-Click Conversion Rate](https://arxiv.org/abs/1804.07931) |
6767
| MMOE | [KDD 2018][Modeling Task Relationships in Multi-task Learning with Multi-gate Mixture-of-Experts](https://dl.acm.org/doi/abs/10.1145/3219819.3220007) |
6868
| PLE | [RecSys 2020][Progressive Layered Extraction (PLE): A Novel Multi-Task Learning (MTL) Model for Personalized Recommendations](https://dl.acm.org/doi/10.1145/3383313.3412236) |
69+
| EDCN | [KDD 2021][Enhancing Explicit and Implicit Feature Interactions via Information Sharing for Parallel Deep CTR Models](https://dlp-kdd.github.io/assets/pdf/DLP-KDD_2021_paper_12.pdf) |
6970

7071
## Citation
7172

deepctr/layers/__init__.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import tensorflow as tf
22

33
from .activation import Dice
4-
from .core import DNN, LocalActivationUnit, PredictionLayer
4+
from .core import DNN, LocalActivationUnit, PredictionLayer, RegulationLayer
55
from .interaction import (CIN, FM, AFMLayer, BiInteractionPooling, CrossNet, CrossNetMix,
66
InnerProductLayer, InteractingLayer,
77
OutterProductLayer, FGCNNLayer, SENETLayer, BilinearInteraction,
8-
FieldWiseBiInteraction, FwFMLayer, FEFMLayer)
8+
FieldWiseBiInteraction, FwFMLayer, FEFMLayer, BridgeLayer)
99
from .normalization import LayerNormalization
1010
from .sequence import (AttentionSequencePoolingLayer, BiasEncoding, BiLSTM,
1111
KMaxPooling, SequencePoolingLayer, WeightedSequenceLayer,
@@ -28,6 +28,7 @@
2828
'SequencePoolingLayer': SequencePoolingLayer,
2929
'AttentionSequencePoolingLayer': AttentionSequencePoolingLayer,
3030
'CIN': CIN,
31+
'RegulationLayer': RegulationLayer,
3132
'InteractingLayer': InteractingLayer,
3233
'LayerNormalization': LayerNormalization,
3334
'BiLSTM': BiLSTM,
@@ -48,5 +49,6 @@
4849
'softmax': softmax,
4950
'FEFMLayer': FEFMLayer,
5051
'reduce_sum': reduce_sum,
51-
'PositionEncoding':PositionEncoding
52+
'PositionEncoding': PositionEncoding,
53+
'BridgeLayer': BridgeLayer
5254
}

deepctr/layers/core.py

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@
1010
from tensorflow.python.keras import backend as K
1111

1212
try:
13-
from tensorflow.python.ops.init_ops_v2 import Zeros, glorot_normal
13+
from tensorflow.python.ops.init_ops_v2 import Zeros, Ones, glorot_normal
1414
except ImportError:
15-
from tensorflow.python.ops.init_ops import Zeros, glorot_normal_initializer as glorot_normal
15+
from tensorflow.python.ops.init_ops import Zeros, Ones, glorot_normal_initializer as glorot_normal
1616

1717
from tensorflow.python.keras.layers import Layer, Dropout
1818

@@ -265,3 +265,59 @@ def get_config(self, ):
265265
config = {'task': self.task, 'use_bias': self.use_bias}
266266
base_config = super(PredictionLayer, self).get_config()
267267
return dict(list(base_config.items()) + list(config.items()))
268+
269+
270+
class RegulationLayer(Layer):
271+
"""Regulation module used in EDCN.
272+
273+
Input shape
274+
- A list of 3D tensor with shape: ``(batch_size,1,embedding_size)``.
275+
276+
Output shape
277+
- 2D tensor with shape: ``(batch_size, embedding_size * field_num)``.
278+
279+
Arguments
280+
- **tau** : Positive float, the temperature coefficient to control
281+
distribution of field-wise gating unit.
282+
283+
- **seed** : A Python integer to use as random seed.
284+
285+
References
286+
- [Enhancing Explicit and Implicit Feature Interactions via Information Sharing for Parallel Deep CTR Models.](https://dlp-kdd.github.io/assets/pdf/DLP-KDD_2021_paper_12.pdf)
287+
"""
288+
289+
def __init__(self, tau=0.1, **kwargs):
290+
if tau == 0:
291+
raise ValueError("RegulationLayer tau can not be zero.")
292+
self.tau = 1.0 / tau
293+
super(RegulationLayer, self).__init__(**kwargs)
294+
295+
def build(self, input_shape):
296+
self.field_num = int(input_shape[1])
297+
self.embedding_size = int(input_shape[2])
298+
self.g = self.add_weight(
299+
shape=(1, self.field_num, 1),
300+
initializer=Ones(),
301+
name=self.name + '_field_weight')
302+
303+
# Be sure to call this somewhere!
304+
super(RegulationLayer, self).build(input_shape)
305+
306+
def call(self, inputs, **kwargs):
307+
308+
if K.ndim(inputs) != 3:
309+
raise ValueError(
310+
"Unexpected inputs dimensions %d, expect to be 3 dimensions" % (K.ndim(inputs)))
311+
312+
feild_gating_score = tf.nn.softmax(self.g * self.tau, 1)
313+
E = inputs * feild_gating_score
314+
return tf.reshape(E, [-1, self.field_num * self.embedding_size])
315+
316+
def compute_output_shape(self, input_shape):
317+
return (None, self.field_num * self.embedding_size)
318+
319+
def get_config(self):
320+
config = {'tau': self.tau}
321+
base_config = super(RegulationLayer, self).get_config()
322+
base_config.update(config)
323+
return base_config

deepctr/layers/interaction.py

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
44
Authors:
55
Weichen Shen,[email protected],
6-
Harshit Pande
6+
Harshit Pande,
7+
78
89
"""
910

@@ -26,6 +27,7 @@
2627

2728
from .activation import activation_layer
2829
from .utils import concat_func, reduce_sum, softmax, reduce_mean
30+
from .core import DNN
2931

3032

3133
class AFMLayer(Layer):
@@ -1489,3 +1491,74 @@ def get_config(self):
14891491
'regularizer': self.regularizer,
14901492
})
14911493
return config
1494+
1495+
1496+
class BridgeLayer(Layer): # ridge
1497+
"""AttentionPoolingLayer layer used in EDCN
1498+
1499+
Input shape
1500+
- A list of 3D tensor with shape: ``(batch_size,1,embedding_size)``. Its length is ``number of subnetworks``.
1501+
1502+
Output shape
1503+
- 2D tensor with shape: ``(batch_size, embedding_size)``.
1504+
1505+
Arguments
1506+
- **activation**: Activation function to use.
1507+
1508+
- **l2_reg**: float between 0 and 1. L2 regularizer strength applied to the kernel weights matrix.
1509+
1510+
- **seed**: A Python integer to use as random seed.
1511+
1512+
References
1513+
- [Enhancing Explicit and Implicit Feature Interactions via Information Sharing for Parallel Deep CTR Models.](https://dlp-kdd.github.io/assets/pdf/DLP-KDD_2021_paper_12.pdf)
1514+
1515+
"""
1516+
1517+
def __init__(self, bridge_type='attention_pooling', activation='relu', l2_reg=0, seed=1024, **kwargs):
1518+
self.bridge_type = bridge_type
1519+
self.activation = activation
1520+
self.l2_reg = l2_reg
1521+
self.seed = seed
1522+
1523+
super(BridgeLayer, self).__init__(**kwargs)
1524+
1525+
def build(self, input_shape):
1526+
if not isinstance(input_shape, list) or len(input_shape) < 2:
1527+
raise ValueError(
1528+
'A `AttentionPoolingLayer` layer should be called '
1529+
'on a list of at least 2 inputs')
1530+
1531+
self.dnn_dim = int(input_shape[0][-1])
1532+
1533+
self.dense = Dense(self.dnn_dim, self.activation)
1534+
self.dense_x = DNN([self.dnn_dim, self.dnn_dim], output_activation='softmax')
1535+
self.dense_h = DNN([self.dnn_dim, self.dnn_dim], output_activation='softmax')
1536+
1537+
super(BridgeLayer, self).build(input_shape) # Be sure to call this somewhere!
1538+
1539+
def call(self, inputs, **kwargs):
1540+
x, h = inputs
1541+
if self.bridge_type == "pointwise_addition":
1542+
return x + h
1543+
elif self.bridge_type == "hadamard_product":
1544+
return x * h
1545+
elif self.bridge_type == "concatenation":
1546+
return self.dense(tf.concat(inputs, axis=-1))
1547+
elif self.bridge_type == "attention_pooling":
1548+
a_x = self.dense_x(x)
1549+
a_h = self.dense_h(h)
1550+
return a_x * x + a_h * h
1551+
1552+
def compute_output_shape(self, input_shape):
1553+
return (None, self.dnn_dim)
1554+
1555+
def get_config(self):
1556+
base_config = super(BridgeLayer, self).get_config().copy()
1557+
config = {
1558+
'bridge_type': self.bridge_type,
1559+
'l2_reg': self.l2_reg,
1560+
'activation': self.activation,
1561+
'seed': self.seed
1562+
}
1563+
config.update(base_config)
1564+
return config

deepctr/models/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@
2020
from .sequence import DIN, DIEN, DSIN, BST
2121
from .wdl import WDL
2222
from .xdeepfm import xDeepFM
23+
from .edcn import EDCN
2324

2425
__all__ = ["AFM", "CCPM", "DCN", "IFM", "DIFM", "DCNMix", "MLR", "DeepFM", "MLR", "NFM", "DIN", "DIEN", "FNN", "PNN",
2526
"WDL", "xDeepFM", "AutoInt", "ONN", "FGCNN", "DSIN", "FiBiNET", 'FLEN', "FwFM", "BST", "DeepFEFM",
26-
"SharedBottom", "ESMM", "MMOE", "PLE"]
27+
"SharedBottom", "ESMM", "MMOE", "PLE", 'EDCN']

deepctr/models/edcn.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
# -*- coding:utf-8 -*-
2+
"""
3+
Author:
4+
5+
6+
Reference:
7+
[1] Chen, B., Wang, Y., Liu, et al. Enhancing Explicit and Implicit Feature Interactions via Information Sharing for Parallel Deep CTR Models. CIKM, 2021, October (https://dlp-kdd.github.io/assets/pdf/DLP-KDD_2021_paper_12.pdf)
8+
"""
9+
import tensorflow as tf
10+
from tensorflow.python.keras.layers import Dense, Lambda, Reshape, Concatenate
11+
from tensorflow.python.keras.models import Model
12+
13+
from ..feature_column import build_input_features, get_linear_logit, input_from_feature_columns
14+
from ..layers.core import PredictionLayer, DNN, RegulationLayer
15+
from ..layers.interaction import CrossNet, BridgeLayer
16+
from ..layers.utils import add_func, concat_func
17+
18+
19+
def EDCN(linear_feature_columns,
20+
dnn_feature_columns,
21+
bridge_type='attention_pooling',
22+
tau=0.1,
23+
use_dense_features=True,
24+
cross_num=2,
25+
cross_parameterization='vector',
26+
l2_reg_linear=1e-5,
27+
l2_reg_embedding=1e-5,
28+
l2_reg_cross=1e-5,
29+
l2_reg_dnn=0,
30+
seed=10000,
31+
dnn_dropout=0,
32+
dnn_use_bn=False,
33+
dnn_activation='relu',
34+
task='binary'):
35+
"""Instantiates the Enhanced Deep&Cross Network architecture.
36+
:param linear_feature_columns: An iterable containing all the features used by linear part of the model.
37+
:param dnn_feature_columns: An iterable containing all the features used by deep part of the model.
38+
:param bridge_type: The type of bridge interaction, one of 'pointwise_addition', 'hadamard_product', 'concatenation', 'attention_pooling'
39+
:param tau: Positive float, the temperature coefficient to control distribution of field-wise gating unit
40+
:param use_dense_features: Whether to use dense features, if True, dense feature will be projected to sparse embedding space
41+
:param cross_num: positive integet,cross layer number
42+
:param cross_parameterization: str, ``"vector"`` or ``"matrix"``, how to parameterize the cross network.
43+
:param l2_reg_linear: float. L2 regularizer strength applied to linear part
44+
:param l2_reg_embedding: float. L2 regularizer strength applied to embedding vector
45+
:param l2_reg_cross: float. L2 regularizer strength applied to cross net
46+
:param l2_reg_dnn: float. L2 regularizer strength applied to DNN
47+
:param seed: integer ,to use as random seed.
48+
:param dnn_dropout: float in [0,1), the probability we will drop out a given DNN coordinate.
49+
:param dnn_use_bn: bool. Whether use BatchNormalization before activation or not DNN
50+
:param dnn_activation: Activation function to use in DNN
51+
:param task: str, ``"binary"`` for binary logloss or ``"regression"`` for regression loss
52+
:return: A Keras model instance.
53+
54+
"""
55+
if cross_num == 0:
56+
raise ValueError("Cross layer num must > 0")
57+
58+
print('EDCN brige type: ', bridge_type)
59+
60+
features = build_input_features(dnn_feature_columns)
61+
inputs_list = list(features.values())
62+
63+
linear_logit = get_linear_logit(features,
64+
linear_feature_columns,
65+
seed=seed,
66+
prefix='linear',
67+
l2_reg=l2_reg_linear)
68+
69+
sparse_embedding_list, dense_value_list = input_from_feature_columns(
70+
features, dnn_feature_columns, l2_reg_embedding, seed)
71+
72+
# project dense value to sparse embedding space, generate a new field feature
73+
sparse_embedding_dim = int(sparse_embedding_list[0].shape[-1])
74+
if use_dense_features:
75+
dense_value_feild = concat_func(dense_value_list)
76+
dense_value_feild = Dense(sparse_embedding_dim, dnn_activation)(dense_value_feild)
77+
dense_value_feild = Lambda(lambda x: tf.expand_dims(x, axis=1))(dense_value_feild)
78+
sparse_embedding_list.append(dense_value_feild)
79+
80+
deep_in = concat_func(sparse_embedding_list, axis=1)
81+
cross_in = concat_func(sparse_embedding_list, axis=1)
82+
field_size = len(sparse_embedding_list)
83+
cross_dim = field_size * int(cross_in[0].shape[-1])
84+
85+
for i in range(cross_num):
86+
deep_in = RegulationLayer(tau)(deep_in)
87+
cross_in = RegulationLayer(tau)(cross_in)
88+
cross_out = CrossNet(1, parameterization=cross_parameterization,
89+
l2_reg=l2_reg_cross)(deep_in)
90+
deep_out = DNN([cross_dim], dnn_activation, l2_reg_dnn,
91+
dnn_dropout, dnn_use_bn, seed=seed)(cross_in)
92+
93+
bridge_out = BridgeLayer(bridge_type)([cross_out, deep_out])
94+
bridge_out_list = Reshape([field_size, sparse_embedding_dim])(bridge_out)
95+
96+
deep_in = bridge_out_list
97+
cross_in = bridge_out_list
98+
99+
stack_out = Concatenate()([cross_out, deep_out, bridge_out])
100+
final_logit = Dense(1, use_bias=False)(stack_out)
101+
102+
final_logit = add_func([final_logit, linear_logit])
103+
output = PredictionLayer(task)(final_logit)
104+
105+
model = Model(inputs=inputs_list, outputs=final_logit)
106+
107+
return model

tests/models/EDCN_test.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import pytest
2+
import tensorflow as tf
3+
4+
from deepctr.models import EDCN
5+
from ..utils import check_model, get_test_data, SAMPLE_SIZE, get_test_data_estimator, check_estimator, \
6+
TEST_Estimator
7+
8+
9+
@pytest.mark.parametrize(
10+
'bridge_type, tau, use_dense_features, cross_num, cross_parameterization, sparse_feature_num',
11+
[
12+
('pointwise_addition', 1, True, 2, 'vector', 3),
13+
('hadamard_product', 1, False, 2, 'vector', 4),
14+
('concatenation', 1, True, 3, 'vector', 5),
15+
('attention_pooling', 1, True, 2, 'matrix', 6),
16+
]
17+
)
18+
def test_EDCN(bridge_type, tau, use_dense_features, cross_num, cross_parameterization, sparse_feature_num):
19+
model_name = "EDCN"
20+
21+
sample_size = SAMPLE_SIZE
22+
x, y, feature_columns = get_test_data(sample_size, sparse_feature_num=sparse_feature_num,
23+
dense_feature_num=sparse_feature_num)
24+
25+
model = EDCN(feature_columns, feature_columns,
26+
bridge_type, tau, use_dense_features, cross_num, cross_parameterization)
27+
check_model(model, model_name, x, y)
28+
29+
30+
if __name__ == "__main__":
31+
pass

0 commit comments

Comments
 (0)