# -*- coding:utf-8 -*-
"""

Author:

Reference:
    [1] Xiao J, Ye H, He X, et al. Attentional factorization machines: Learning the weight of feature interactions via attention networks[J]. arXiv preprint arXiv:1708.04617, 2017.
        (https://arxiv.org/abs/1708.04617)

"""
import tensorflow as tf

from ..feature_column import get_linear_logit, input_from_feature_columns
from ..utils import deepctr_model_fn, DNN_SCOPE_NAME, variable_scope
from ...layers.interaction import AFMLayer, FM
from ...layers.utils import concat_func


def AFMEstimator(linear_feature_columns, dnn_feature_columns, use_attention=True, attention_factor=8,
                 l2_reg_linear=1e-5, l2_reg_embedding=1e-5, l2_reg_att=1e-5, afm_dropout=0, seed=1024,
                 task='binary', model_dir=None, config=None, linear_optimizer='Ftrl',
                 dnn_optimizer='Adagrad', training_chief_hooks=None):
| 24 | + """Instantiates the Attentional Factorization Machine architecture. |
| 25 | +
|
| 26 | + :param linear_feature_columns: An iterable containing all the features used by linear part of the model. |
| 27 | + :param dnn_feature_columns: An iterable containing all the features used by deep part of the model. |
| 28 | + :param use_attention: bool,whether use attention or not,if set to ``False``.it is the same as **standard Factorization Machine** |
| 29 | + :param attention_factor: positive integer,units in attention net |
| 30 | + :param l2_reg_linear: float. L2 regularizer strength applied to linear part |
| 31 | + :param l2_reg_embedding: float. L2 regularizer strength applied to embedding vector |
| 32 | + :param l2_reg_att: float. L2 regularizer strength applied to attention net |
| 33 | + :param afm_dropout: float in [0,1), Fraction of the attention net output units to dropout. |
| 34 | + :param seed: integer ,to use as random seed. |
| 35 | + :param task: str, ``"binary"`` for binary logloss or ``"regression"`` for regression loss |
| 36 | + :param model_dir: Directory to save model parameters, graph and etc. This can |
| 37 | + also be used to load checkpoints from the directory into a estimator |
| 38 | + to continue training a previously saved model. |
| 39 | + :param config: tf.RunConfig object to configure the runtime settings. |
| 40 | + :param linear_optimizer: An instance of `tf.Optimizer` used to apply gradients to |
| 41 | + the linear part of the model. Defaults to FTRL optimizer. |
| 42 | + :param dnn_optimizer: An instance of `tf.Optimizer` used to apply gradients to |
| 43 | + the deep part of the model. Defaults to Adagrad optimizer. |
| 44 | + :param training_chief_hooks: Iterable of `tf.train.SessionRunHook` objects to |
| 45 | + run on the chief worker during training. |
| 46 | + :return: A Tensorflow Estimator instance. |
| 47 | +
|
| 48 | + """ |
| 49 | + |
| 50 | + def _model_fn(features, labels, mode, config): |
| 51 | + train_flag = (mode == tf.estimator.ModeKeys.TRAIN) |
| 52 | + |
| 53 | + linear_logits = get_linear_logit(features, linear_feature_columns, l2_reg_linear=l2_reg_linear) |
| 54 | + |
| 55 | + with variable_scope(DNN_SCOPE_NAME): |
| 56 | + sparse_embedding_list, dense_value_list = input_from_feature_columns(features, dnn_feature_columns, |
| 57 | + l2_reg_embedding=l2_reg_embedding) |
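            # Per [1], attention replaces FM's uniform weighting of pairwise
            # interactions: roughly y_AFM = p^T * sum_{i<j} a_ij * (e_i ⊙ e_j),
            # where a_ij is a learned attention score for the embedding pair
            # (e_i, e_j). A notational sketch, not the exact layer internals.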
            if use_attention:
                # Attention-weighted pooling of the pairwise interactions;
                # dropout in the attention net is active only while training.
                fm_logit = AFMLayer(attention_factor, l2_reg_att, afm_dropout,
                                    seed)(sparse_embedding_list, training=train_flag)
            else:
                # Plain FM pooling: every pairwise interaction gets equal weight.
                fm_logit = FM()(concat_func(sparse_embedding_list, axis=1))

        logits = linear_logits + fm_logit

        # deepctr_model_fn builds the loss and metrics for the given task and
        # hands linear_optimizer and dnn_optimizer to the linear and deep parts.
        return deepctr_model_fn(features, mode, logits, labels, task, linear_optimizer, dnn_optimizer,
                                training_chief_hooks=training_chief_hooks)

    return tf.estimator.Estimator(_model_fn, model_dir=model_dir, config=config)
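
# Usage sketch (illustrative, not part of this module): the feature names,
# hash bucket sizes, and `train_input_fn` below are hypothetical placeholders.
# AFM requires every sparse feature to share one embedding dimension (here 4).
#
#   import tensorflow as tf
#
#   user_id = tf.feature_column.categorical_column_with_hash_bucket('user_id', 1000)
#   item_id = tf.feature_column.categorical_column_with_hash_bucket('item_id', 5000)
#
#   linear_cols = [user_id, item_id]
#   dnn_cols = [tf.feature_column.embedding_column(user_id, 4),
#               tf.feature_column.embedding_column(item_id, 4)]
#
#   model = AFMEstimator(linear_cols, dnn_cols, task='binary')
#   model.train(train_input_fn)  # train_input_fn yields (features, labels)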