1
1
# -*- coding:utf-8 -*-
2
2
"""
3
3
4
- Author:
5
-
4
+ Authors:
5
+
6
+ Harshit Pande
6
7
7
8
"""
8
9
11
12
import tensorflow as tf
12
13
from tensorflow .python .keras import backend as K
13
14
from tensorflow .python .keras .initializers import (Zeros , glorot_normal ,
14
- glorot_uniform )
15
+ glorot_uniform , TruncatedNormal )
15
16
from tensorflow .python .keras .layers import Layer
16
17
from tensorflow .python .keras .regularizers import l2
18
+ from tensorflow .python .keras .backend import batch_dot
17
19
from tensorflow .python .layers import utils
18
20
19
21
from .activation import activation_layer
@@ -1052,7 +1054,7 @@ class FieldWiseBiInteraction(Layer):
1052
1054
1053
1055
Output shape
1054
1056
- 2D tensor with shape: ``(batch_size,embedding_size)``.
1055
-
1057
+
1056
1058
Arguments
1057
1059
- **use_bias** : Boolean, if use bias.
1058
1060
- **seed** : A Python integer to use as random seed.
@@ -1062,7 +1064,7 @@ class FieldWiseBiInteraction(Layer):
1062
1064
1063
1065
"""
1064
1066
1065
- def __init__ (self ,use_bias = True , seed = 1024 , ** kwargs ):
1067
+ def __init__ (self , use_bias = True , seed = 1024 , ** kwargs ):
1066
1068
self .use_bias = use_bias
1067
1069
self .seed = seed
1068
1070
@@ -1167,3 +1169,80 @@ def get_config(self, ):
1167
1169
config = {'use_bias' : self .use_bias , 'seed' : self .seed }
1168
1170
base_config = super (FieldWiseBiInteraction , self ).get_config ()
1169
1171
return dict (list (base_config .items ()) + list (config .items ()))
1172
+
1173
+
1174
+ class FwFM (Layer ):
1175
+ """Field-weighted Factorization Machines
1176
+
1177
+ Input shape
1178
+ - 3D tensor with shape: ``(batch_size,field_size,embedding_size)``.
1179
+
1180
+ Output shape
1181
+ - 2D tensor with shape: ``(batch_size, 1)``.
1182
+
1183
+ Arguments
1184
+ - **num_fields** : integer for number of fields
1185
+ - **regularizer** : L2 regularizer weight for the field strength parameters of FwFM
1186
+
1187
+ References
1188
+ - [Field-weighted Factorization Machines for Click-Through Rate Prediction in Display Advertising]
1189
+ https://arxiv.org/pdf/1806.03514.pdf
1190
+ """
1191
+
1192
+ def __init__ (self , num_fields = 4 , regularizer = 0.000001 , ** kwargs ):
1193
+ self .num_fields = num_fields
1194
+ self .regularizer = regularizer
1195
+ super (FwFM , self ).__init__ (** kwargs )
1196
+
1197
+ def build (self , input_shape ):
1198
+ if len (input_shape ) != 3 :
1199
+ raise ValueError ("Unexpected inputs dimensions % d,\
1200
+ expect to be 3 dimensions" % (len (input_shape )))
1201
+
1202
+ if input_shape [1 ] != self .num_fields :
1203
+ raise ValueError ("Mismatch in number of fields {} and \
1204
+ concatenated embeddings dims {}" .format (self .num_fields , input_shape [1 ]))
1205
+
1206
+ self .field_strengths = self .add_weight (name = 'field_pair_strengths' ,
1207
+ shape = (self .num_fields , self .num_fields ),
1208
+ initializer = TruncatedNormal (),
1209
+ regularizer = l2 (self .regularizer ),
1210
+ trainable = True )
1211
+
1212
+ super (FwFM , self ).build (input_shape ) # Be sure to call this somewhere!
1213
+
1214
+ def call (self , inputs , ** kwargs ):
1215
+ if K .ndim (inputs ) != 3 :
1216
+ raise ValueError (
1217
+ "Unexpected inputs dimensions %d, expect to be 3 dimensions"
1218
+ % (K .ndim (inputs )))
1219
+
1220
+ if inputs .shape [1 ] != self .num_fields :
1221
+ raise ValueError ("Mismatch in number of fields {} and \
1222
+ concatenated embeddings dims {}" .format (self .num_fields , inputs .shape [1 ]))
1223
+
1224
+ pairwise_inner_prods = []
1225
+ for fi , fj in itertools .combinations (range (self .num_fields ), 2 ):
1226
+ # get field strength for pair fi and fj
1227
+ r_ij = self .field_strengths [fi , fj ]
1228
+
1229
+ # get embeddings for the features of both the fields
1230
+ feat_embed_i = tf .squeeze (inputs [0 :, fi :fi + 1 , 0 :], axis = 1 )
1231
+ feat_embed_j = tf .squeeze (inputs [0 :, fj :fj + 1 , 0 :], axis = 1 )
1232
+
1233
+ f = tf .scalar_mul (r_ij , batch_dot (feat_embed_i , feat_embed_j , axes = 1 ))
1234
+ pairwise_inner_prods .append (f )
1235
+
1236
+ sum_ = tf .add_n (pairwise_inner_prods )
1237
+ return sum_
1238
+
1239
+ def compute_output_shape (self , input_shape ):
1240
+ return (None , 1 )
1241
+
1242
+ def get_config (self ):
1243
+ config = super (FwFM , self ).get_config ().copy ()
1244
+ config .update ({
1245
+ 'num_fields' : self .num_fields ,
1246
+ 'regularizer' : self .regularizer
1247
+ })
1248
+ return config
0 commit comments