Skip to content

Commit f6b34cb

Browse files
author
浅梦
authored
add kv seq feature
2 parents 444000c + 5f43ddb commit f6b34cb

File tree

2 files changed

+105
-2
lines changed

2 files changed

+105
-2
lines changed

deepctr/inputs.py

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,9 @@
1313
from tensorflow.python.keras.layers import Embedding, Input, Flatten
1414
from tensorflow.python.keras.regularizers import l2
1515

16-
from .layers.sequence import SequencePoolingLayer
16+
from .layers.sequence import SequencePoolingLayer, SequenceMultiplyLayer
1717
from .layers.utils import Hash,concat_fun,Linear
1818

19-
2019
class SparseFeat(namedtuple('SparseFeat', ['name', 'dimension', 'use_hash', 'dtype','embedding_name','embedding'])):
2120
__slots__ = ()
2221

@@ -25,6 +24,16 @@ def __new__(cls, name, dimension, use_hash=False, dtype="int32", embedding_name=
2524
embedding_name = name
2625
return super(SparseFeat, cls).__new__(cls, name, dimension, use_hash, dtype, embedding_name,embedding)
2726

27+
def __hash__(self):
    """Hash by feature name only, so features dedupe by name in sets/dicts."""
    return hash(self.name)

def __eq__(self, other):
    """Features compare equal iff their names match (dimension etc. ignored)."""
    return self.name == other.name

def __repr__(self):
    return 'SparseFeat:'+self.name
2837

2938
class DenseFeat(namedtuple('DenseFeat', ['name', 'dimension', 'dtype'])):
3039
__slots__ = ()
@@ -33,6 +42,16 @@ def __new__(cls, name, dimension=1, dtype="float32"):
3342

3443
return super(DenseFeat, cls).__new__(cls, name, dimension, dtype)
3544

45+
def __hash__(self):
    """Hash by feature name only, so features dedupe by name in sets/dicts."""
    return hash(self.name)

def __eq__(self, other):
    """Features compare equal iff their names match (dimension etc. ignored)."""
    return self.name == other.name

def __repr__(self):
    return 'DenseFeat:'+self.name
3655

3756
class VarLenSparseFeat(namedtuple('VarLenFeat', ['name', 'dimension', 'maxlen', 'combiner', 'use_hash', 'dtype','embedding_name','embedding'])):
3857
__slots__ = ()
@@ -42,6 +61,17 @@ def __new__(cls, name, dimension, maxlen, combiner="mean", use_hash=False, dtype
4261
embedding_name = name
4362
return super(VarLenSparseFeat, cls).__new__(cls, name, dimension, maxlen, combiner, use_hash, dtype, embedding_name,embedding)
4463

64+
def __hash__(self):
    """Hash by feature name only, so features dedupe by name in sets/dicts."""
    return hash(self.name)

def __eq__(self, other):
    """Features compare equal iff their names match (dimension etc. ignored)."""
    return self.name == other.name

def __repr__(self):
    return 'VarLenSparseFeat:'+self.name
74+
4575
def get_feature_names(feature_columns):
    """Return the list of raw input names built from *feature_columns*.

    Delegates to ``build_input_features`` (presumably returns an ordered
    mapping of input name -> Input tensor — order follows that mapping).
    """
    # Iterating a mapping yields its keys, so list(m) == list(m.keys()).
    return list(build_input_features(feature_columns))
@@ -209,6 +239,30 @@ def get_varlen_pooling_list(embedding_dict, features, varlen_sparse_feature_colu
209239
pooling_vec_list.append(vec)
210240
return pooling_vec_list
211241

242+
def get_varlen_multiply_list(embedding_dict, features, varlen_sparse_feature_columns_name_dict):
    """Build sum-pooled, value-weighted sequence embeddings.

    For every (key feature -> value features) pair, the key's sequence
    embedding is multiplied element-wise by each paired value input
    (a ``VarLenSparseFeat`` embedding or a ``DenseFeat`` weight tensor) via
    ``SequenceMultiplyLayer``, then sum-pooled over the sequence axis with
    ``SequencePoolingLayer``.

    :param embedding_dict: dict, feature name -> sequence embedding tensor.
    :param features: dict, input name -> Keras ``Input`` tensor.
    :param varlen_sparse_feature_columns_name_dict: dict, key feature column
        -> iterable of value feature columns paired with it.
    :return: list of pooled tensors, one per (key, value) pair.
    :raises TypeError: if a value feature is neither ``VarLenSparseFeat``
        nor ``DenseFeat``.
    """
    # Fix: removed stray debug `print(embedding_dict)` left in library code.
    multiply_vec_list = []
    for key_feature, value_features in varlen_sparse_feature_columns_name_dict.items():
        # Depends only on the key feature — hoisted out of the inner loop.
        key_feature_length_name = key_feature.name + '_seq_length'
        for value_feature in value_features:
            if isinstance(value_feature, VarLenSparseFeat):
                value_input = embedding_dict[value_feature.name]
            elif isinstance(value_feature, DenseFeat):
                value_input = features[value_feature.name]
            else:
                raise TypeError("Invalid feature column type,got", type(value_feature))
            if key_feature_length_name in features:
                # Explicit sequence-length input exists: mask by length.
                varlen_vec = SequenceMultiplyLayer(supports_masking=False)(
                    [embedding_dict[key_feature.name], features[key_feature_length_name], value_input])
                vec = SequencePoolingLayer('sum', supports_masking=False)(
                    [varlen_vec, features[key_feature_length_name]])
            else:
                # No length input: rely on the mask propagated by upstream layers.
                varlen_vec = SequenceMultiplyLayer(supports_masking=True)(
                    [embedding_dict[key_feature.name], value_input])
                vec = SequencePoolingLayer('sum', supports_masking=True)(varlen_vec)
            multiply_vec_list.append(vec)
    return multiply_vec_list
265+
212266
def get_dense_input(features,feature_columns):
213267
dense_feature_columns = list(filter(lambda x:isinstance(x,DenseFeat),feature_columns)) if feature_columns else []
214268
dense_input_list = []

deepctr/layers/sequence.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -739,3 +739,52 @@ def get_config(self, ):
739739
config = {'k': self.k, 'axis': self.axis}
740740
base_config = super(KMaxPooling, self).get_config()
741741
return dict(list(base_config.items()) + list(config.items()))
742+
743+
744+
class SequenceMultiplyLayer(Layer):
    """Element-wise multiply a key sequence embedding by a value input,
    zeroing masked (padding) timesteps of the key first.

    Input (list):
      - supports_masking=True:  [key_input, value_input]; the mask comes
        from the upstream layer (e.g. an Embedding with mask_zero=True).
      - supports_masking=False: [key_input, key_length_input, value_input];
        the mask is rebuilt from key_length_input via tf.sequence_mask.
    NOTE(review): key_input is presumably (batch, seq_len, embedding_size)
    and value_input (batch, seq_len) or (batch, seq_len, embedding_size) —
    inferred from the expand_dims/tile below; confirm against callers.
    """

    def __init__(self, supports_masking, **kwargs):
        super(SequenceMultiplyLayer, self).__init__(**kwargs)
        # Whether to consume a Keras-propagated mask instead of an explicit
        # sequence-length input.
        self.supports_masking = supports_masking

    def build(self, input_shape):
        # Without mask propagation, remember the static max sequence length
        # of the key input so call() can rebuild the mask from lengths.
        if not self.supports_masking:
            self.seq_len_max = int(input_shape[0][1])
        super(SequenceMultiplyLayer, self).build(
            input_shape)  # Be sure to call this somewhere!

    def call(self, input_list, mask=None, **kwargs):
        if self.supports_masking:
            if mask is None:
                raise ValueError(
                    "When supports_masking=True,input must support masking")
            key_input, value_input = input_list
            # Use the key input's mask; expand to (batch, seq_len, 1).
            mask = tf.cast(mask[0], tf.float32)
            mask = tf.expand_dims(mask, axis=2)
        else:
            key_input, key_length_input, value_input = input_list
            # sequence_mask yields (batch, ?, seq_len_max); transpose so the
            # time axis lines up with key_input's axis 1.
            mask = tf.sequence_mask(key_length_input,
                                    self.seq_len_max, dtype=tf.float32)
            mask = tf.transpose(mask, (0, 2, 1))

        # Broadcast the 0/1 mask across the embedding dimension and zero
        # out padded timesteps of the key.
        embedding_size = key_input.shape[-1]
        mask = tf.tile(mask, [1, 1, embedding_size])
        key_input *= mask
        # NOTE(review): len(tf.shape(x)) reads the static rank of the shape
        # tensor; behavior differs across TF versions — confirm on target TF.
        if len(tf.shape(value_input)) == 2:
            # 2-D value (batch, seq_len): lift to 3-D and repeat per dim.
            value_input = tf.expand_dims(value_input, axis=2)
            value_input = tf.tile(value_input, [1, 1, embedding_size])
        return tf.multiply(key_input,value_input)

    def compute_output_shape(self, input_shape):
        # Output has the key input's shape.
        return input_shape[0]

    def compute_mask(self, inputs, mask):
        # Propagate the key's mask downstream only in masking mode.
        if self.supports_masking:
            return mask[0]
        else:
            return None

    def get_config(self, ):
        """Return the layer config for serialization (adds supports_masking)."""
        config = {'supports_masking': self.supports_masking}
        base_config = super(SequenceMultiplyLayer, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

0 commit comments

Comments
 (0)