13
13
from tensorflow .python .keras .layers import Embedding , Input , Flatten
14
14
from tensorflow .python .keras .regularizers import l2
15
15
16
- from .layers .sequence import SequencePoolingLayer
16
+ from .layers .sequence import SequencePoolingLayer , SequenceMultiplyLayer
17
17
from .layers .utils import Hash ,concat_fun ,Linear
18
18
19
-
20
19
class SparseFeat (namedtuple ('SparseFeat' , ['name' , 'dimension' , 'use_hash' , 'dtype' ,'embedding_name' ,'embedding' ])):
21
20
__slots__ = ()
22
21
@@ -25,6 +24,16 @@ def __new__(cls, name, dimension, use_hash=False, dtype="int32", embedding_name=
25
24
embedding_name = name
26
25
return super (SparseFeat , cls ).__new__ (cls , name , dimension , use_hash , dtype , embedding_name ,embedding )
27
26
27
+ def __hash__ (self ):
28
+ return self .name .__hash__ ()
29
+
30
+ def __eq__ (self , other ):
31
+ if self .name == other .name :
32
+ return True
33
+ return False
34
+
35
+ def __repr__ (self ):
36
+ return 'SparseFeat:' + self .name
28
37
29
38
class DenseFeat (namedtuple ('DenseFeat' , ['name' , 'dimension' , 'dtype' ])):
30
39
__slots__ = ()
@@ -33,6 +42,16 @@ def __new__(cls, name, dimension=1, dtype="float32"):
33
42
34
43
return super (DenseFeat , cls ).__new__ (cls , name , dimension , dtype )
35
44
45
+ def __hash__ (self ):
46
+ return self .name .__hash__ ()
47
+
48
+ def __eq__ (self , other ):
49
+ if self .name == other .name :
50
+ return True
51
+ return False
52
+
53
+ def __repr__ (self ):
54
+ return 'DenseFeat:' + self .name
36
55
37
56
class VarLenSparseFeat (namedtuple ('VarLenFeat' , ['name' , 'dimension' , 'maxlen' , 'combiner' , 'use_hash' , 'dtype' ,'embedding_name' ,'embedding' ])):
38
57
__slots__ = ()
@@ -42,6 +61,17 @@ def __new__(cls, name, dimension, maxlen, combiner="mean", use_hash=False, dtype
42
61
embedding_name = name
43
62
return super (VarLenSparseFeat , cls ).__new__ (cls , name , dimension , maxlen , combiner , use_hash , dtype , embedding_name ,embedding )
44
63
64
+ def __hash__ (self ):
65
+ return self .name .__hash__ ()
66
+
67
+ def __eq__ (self , other ):
68
+ if self .name == other .name :
69
+ return True
70
+ return False
71
+
72
+ def __repr__ (self ):
73
+ return 'VarLenSparseFeat:' + self .name
74
+
45
75
def get_feature_names (feature_columns ):
46
76
features = build_input_features (feature_columns )
47
77
return list (features .keys ())
@@ -209,6 +239,30 @@ def get_varlen_pooling_list(embedding_dict, features, varlen_sparse_feature_colu
209
239
pooling_vec_list .append (vec )
210
240
return pooling_vec_list
211
241
242
+ def get_varlen_multiply_list (embedding_dict , features , varlen_sparse_feature_columns_name_dict ):
243
+ multiply_vec_list = []
244
+ print (embedding_dict )
245
+ for key_feature in varlen_sparse_feature_columns_name_dict :
246
+ for value_feature in varlen_sparse_feature_columns_name_dict [key_feature ]:
247
+ key_feature_length_name = key_feature .name + '_seq_length'
248
+ if isinstance (value_feature , VarLenSparseFeat ):
249
+ value_input = embedding_dict [value_feature .name ]
250
+ elif isinstance (value_feature , DenseFeat ):
251
+ value_input = features [value_feature .name ]
252
+ else :
253
+ raise TypeError ("Invalid feature column type,got" ,type (value_feature ))
254
+ if key_feature_length_name in features :
255
+ varlen_vec = SequenceMultiplyLayer (supports_masking = False )(
256
+ [embedding_dict [key_feature .name ], features [key_feature_length_name ], value_input ])
257
+ vec = SequencePoolingLayer ('sum' , supports_masking = False )(
258
+ [varlen_vec , features [key_feature_length_name ]])
259
+ else :
260
+ varlen_vec = SequenceMultiplyLayer (supports_masking = True )(
261
+ [embedding_dict [key_feature .name ], value_input ])
262
+ vec = SequencePoolingLayer ('sum' , supports_masking = True )( varlen_vec )
263
+ multiply_vec_list .append (vec )
264
+ return multiply_vec_list
265
+
212
266
def get_dense_input (features ,feature_columns ):
213
267
dense_feature_columns = list (filter (lambda x :isinstance (x ,DenseFeat ),feature_columns )) if feature_columns else []
214
268
dense_input_list = []
0 commit comments