|
| 1 | +# -*- coding:utf-8 -*- |
| 2 | +""" |
| 3 | +Author: |
| 4 | + Harshit Pande |
| 5 | +
|
| 6 | +Reference: |
| 7 | + [1] Field-Embedded Factorization Machines for Click-through Rate Prediction |
| 8 | + (https://arxiv.org/pdf/2009.09931.pdf) |
| 9 | +
|
| 10 | + this file also supports all the possible Ablation studies for reproducibility |
| 11 | +
|
| 12 | +""" |
| 13 | + |
| 14 | +from itertools import chain |
| 15 | + |
| 16 | +import tensorflow as tf |
| 17 | + |
| 18 | +from ..feature_column import input_from_feature_columns, get_linear_logit, build_input_features, DEFAULT_GROUP_NAME |
| 19 | +from ..layers.core import PredictionLayer, DNN |
| 20 | +from ..layers.interaction import FEFMLayer |
| 21 | +from ..layers.utils import concat_func, combined_dnn_input, reduce_sum |
| 22 | + |
| 23 | + |
def DeepFEFM(linear_feature_columns, dnn_feature_columns, embedding_size=48, use_fefm=True,
             dnn_hidden_units=(1024, 1024, 1024), l2_reg_linear=0.000001, l2_reg_embedding_feat=0.00001,
             l2_reg_embedding_field=0.0000001, l2_reg_dnn=0, seed=1024, dnn_dropout=0.2,
             exclude_feature_embed_in_dnn=False, use_linear=True, use_fefm_embed_in_dnn=True,
             dnn_activation='relu', dnn_use_bn=False, task='binary'):
    """Instantiates the DeepFEFM Network architecture or the shallow FEFM architecture (Ablation studies supported)

    The final logit is the sum of whichever of the linear, FEFM and DNN logits are enabled:
    ``use_linear`` enables the linear logit, ``use_fefm`` enables the FEFM logit and a non-empty
    ``dnn_hidden_units`` enables the DNN logit.

    :param linear_feature_columns: An iterable containing all the features used by linear part of the model.
    :param dnn_feature_columns: An iterable containing all the features used by deep part of the model.
    :param embedding_size: positive integer, embedding size passed to the FEFM layer
    :param use_fefm: bool, use FEFM logit or not (does not affect FEFM embeddings in DNN, controls only the
        use of the final FEFM logit). Evaluated by truthiness.
    :param dnn_hidden_units: list, list of positive integer or empty list, the layer number and units in each
        layer of DNN. An empty list removes the DNN logit from the final logit.
    :param l2_reg_linear: float. L2 regularizer strength applied to linear part
    :param l2_reg_embedding_feat: float. L2 regularizer strength applied to embedding vector of features
    :param l2_reg_embedding_field: float. L2 regularizer strength passed to the FEFM layer
    :param l2_reg_dnn: float. L2 regularizer strength applied to DNN
    :param seed: integer, to use as random seed.
    :param dnn_dropout: float in [0,1), the probability we will drop out a given DNN coordinate.
    :param exclude_feature_embed_in_dnn: bool, used in ablation studies for removing feature embeddings in DNN.
        Only takes effect when ``use_fefm_embed_in_dnn`` is true.
    :param use_linear: bool, used in ablation studies. Evaluated by truthiness.
    :param use_fefm_embed_in_dnn: bool, True if FEFM interaction embeddings are to be used in the DNN input
        (set False for Ablation)
    :param dnn_activation: Activation function to use in DNN
    :param dnn_use_bn: bool. Whether use BatchNormalization before activation or not in DNN
    :param task: str, ``"binary"`` for binary logloss or ``"regression"`` for regression loss
    :return: A Keras model instance.
    :raises NotImplementedError: if the linear part, the FEFM logit and the DNN are all disabled, since no
        logit would be left to predict from.
    """

    use_dnn = len(dnn_hidden_units) > 0

    # Fail fast, before any layer is created: at least one component must contribute a logit.
    if not (use_linear or use_fefm or use_dnn):
        raise NotImplementedError(
            "DeepFEFM needs at least one of use_linear=True, use_fefm=True or a non-empty dnn_hidden_units")

    features = build_input_features(linear_feature_columns + dnn_feature_columns)
    inputs_list = list(features.values())

    linear_logit = get_linear_logit(features, linear_feature_columns, l2_reg=l2_reg_linear, seed=seed,
                                    prefix='linear')

    group_embedding_dict, dense_value_list = input_from_feature_columns(
        features, dnn_feature_columns, l2_reg_embedding_feat, seed, support_group=True)

    # FEFM interaction embeddings are computed only for the default feature group.
    fefm_interaction_embedding = concat_func(
        [FEFMLayer(num_fields=len(v), embedding_size=embedding_size,
                   regularizer=l2_reg_embedding_field)(concat_func(v, axis=1))
         for k, v in group_embedding_dict.items() if k == DEFAULT_GROUP_NAME],
        axis=1)

    # By default the DNN sees the feature embeddings of every group plus the dense values.
    dnn_input = combined_dnn_input(list(chain.from_iterable(group_embedding_dict.values())), dense_value_list)

    # if use_fefm_embed_in_dnn is set to False it is Ablation4 (Use false only for Ablation)
    if use_fefm_embed_in_dnn:
        if exclude_feature_embed_in_dnn:
            # Ablation3: remove feature vector embeddings from the DNN input
            dnn_input = fefm_interaction_embedding
        else:
            # No ablation
            dnn_input = concat_func([dnn_input, fefm_interaction_embedding], axis=1)

    # All three logits are constructed unconditionally; the flags below decide which ones
    # actually feed the model output.
    dnn_out = DNN(dnn_hidden_units, dnn_activation, l2_reg_dnn, dnn_dropout, dnn_use_bn, seed=seed)(dnn_input)

    dnn_logit = tf.keras.layers.Dense(
        1, use_bias=False, kernel_initializer=tf.keras.initializers.glorot_normal(seed))(dnn_out)

    fefm_logit = tf.keras.layers.Lambda(lambda x: reduce_sum(x, axis=1, keep_dims=True))(fefm_interaction_embedding)

    # Collect the enabled logits in a fixed order (linear, FEFM, DNN). This covers:
    #   linear only / linear + FEFM / linear + Deep (Ablation1) / linear + FEFM + Deep /
    #   FEFM only (shallow) / Deep only / FEFM + Deep (Ablation2)
    enabled_logits = []
    if use_linear:
        enabled_logits.append(linear_logit)
    if use_fefm:
        enabled_logits.append(fefm_logit)
    if use_dnn:
        enabled_logits.append(dnn_logit)

    # A single logit is used directly; the add layer is only used to merge two or more.
    if len(enabled_logits) == 1:
        final_logit = enabled_logits[0]
    else:
        final_logit = tf.keras.layers.add(enabled_logits)

    output = PredictionLayer(task)(final_logit)
    model = tf.keras.models.Model(inputs=inputs_list, outputs=output)
    return model
0 commit comments