@@ -46,7 +46,7 @@ def __init__(self, mode='mean', supports_masking=False, **kwargs):
46
46
if mode not in ['sum' , 'mean' , 'max' ]:
47
47
raise ValueError ("mode must be sum or mean" )
48
48
self .mode = mode
49
- self .eps = 1e-8
49
+ self .eps = tf . constant ( 1e-8 , tf . float32 )
50
50
super (SequencePoolingLayer , self ).__init__ (** kwargs )
51
51
52
52
self .supports_masking = supports_masking
@@ -85,7 +85,7 @@ def call(self, seq_value_len_list, mask=None, **kwargs):
85
85
hist = reduce_sum (hist , 1 , keep_dims = False )
86
86
87
87
if self .mode == "mean" :
88
- hist = div (hist , user_behavior_length + self .eps )
88
+ hist = div (hist , tf . cast ( user_behavior_length , tf . float32 ) + self .eps )
89
89
90
90
hist = tf .expand_dims (hist , axis = 1 )
91
91
return hist
@@ -105,6 +105,83 @@ def get_config(self, ):
105
105
return dict (list (base_config .items ()) + list (config .items ()))
106
106
107
107
108
+ class WeightedSequenceLayer (Layer ):
109
+ """The WeightedSequenceLayer is used to apply weight score on variable-length sequence feature/multi-value feature.
110
+
111
+ Input shape
112
+ - A list of two tensor [seq_value,seq_len,seq_weight]
113
+
114
+ - seq_value is a 3D tensor with shape: ``(batch_size, T, embedding_size)``
115
+
116
+ - seq_len is a 2D tensor with shape : ``(batch_size, 1)``,indicate valid length of each sequence.
117
+
118
+ - seq_weight is a 3D tensor with shape: ``(batch_size, T, 1)``
119
+
120
+ Output shape
121
+ - 3D tensor with shape: ``(batch_size, T, embedding_size)``.
122
+
123
+ Arguments
124
+ - **weight_normalization**: bool.Whether normalize the weight socre before applying to sequence.
125
+
126
+ - **supports_masking**:If True,the input need to support masking.
127
+ """
128
+
129
+ def __init__ (self ,weight_normalization = False , supports_masking = False , ** kwargs ):
130
+ super (WeightedSequenceLayer , self ).__init__ (** kwargs )
131
+ self .weight_normalization = weight_normalization
132
+ self .supports_masking = supports_masking
133
+
134
+ def build (self , input_shape ):
135
+ if not self .supports_masking :
136
+ self .seq_len_max = int (input_shape [0 ][1 ])
137
+ super (WeightedSequenceLayer , self ).build (
138
+ input_shape ) # Be sure to call this somewhere!
139
+
140
+ def call (self , input_list , mask = None , ** kwargs ):
141
+ if self .supports_masking :
142
+ if mask is None :
143
+ raise ValueError (
144
+ "When supports_masking=True,input must support masking" )
145
+ key_input , value_input = input_list
146
+ mask = tf .expand_dims (mask [0 ], axis = 2 )
147
+ else :
148
+ key_input , key_length_input , value_input = input_list
149
+ mask = tf .sequence_mask (key_length_input ,
150
+ self .seq_len_max , dtype = tf .bool )
151
+ mask = tf .transpose (mask , (0 , 2 , 1 ))
152
+
153
+ embedding_size = key_input .shape [- 1 ]
154
+
155
+ if self .weight_normalization :
156
+ paddings = tf .ones_like (value_input ) * (- 2 ** 32 + 1 )
157
+ else :
158
+ paddings = tf .zeros_like (value_input )
159
+ value_input = tf .where (mask , value_input , paddings )
160
+
161
+ if self .weight_normalization :
162
+ value_input = softmax (value_input ,dim = 1 )
163
+
164
+
165
+ if len (value_input .shape ) == 2 :
166
+ value_input = tf .expand_dims (value_input , axis = 2 )
167
+ value_input = tf .tile (value_input , [1 , 1 , embedding_size ])
168
+
169
+ return tf .multiply (key_input ,value_input )
170
+
171
+ def compute_output_shape (self , input_shape ):
172
+ return input_shape [0 ]
173
+
174
+ def compute_mask (self , inputs , mask ):
175
+ if self .supports_masking :
176
+ return mask [0 ]
177
+ else :
178
+ return None
179
+
180
+ def get_config (self , ):
181
+ config = {'supports_masking' : self .supports_masking }
182
+ base_config = super (WeightedSequenceLayer , self ).get_config ()
183
+ return dict (list (base_config .items ()) + list (config .items ()))
184
+
108
185
class AttentionSequencePoolingLayer (Layer ):
109
186
"""The Attentional sequence pooling operation used in DIN.
110
187
@@ -741,50 +818,3 @@ def get_config(self, ):
741
818
return dict (list (base_config .items ()) + list (config .items ()))
742
819
743
820
744
- class SequenceMultiplyLayer (Layer ):
745
-
746
- def __init__ (self , supports_masking , ** kwargs ):
747
- super (SequenceMultiplyLayer , self ).__init__ (** kwargs )
748
- self .supports_masking = supports_masking
749
-
750
- def build (self , input_shape ):
751
- if not self .supports_masking :
752
- self .seq_len_max = int (input_shape [0 ][1 ])
753
- super (SequenceMultiplyLayer , self ).build (
754
- input_shape ) # Be sure to call this somewhere!
755
-
756
- def call (self , input_list , mask = None , ** kwargs ):
757
- if self .supports_masking :
758
- if mask is None :
759
- raise ValueError (
760
- "When supports_masking=True,input must support masking" )
761
- key_input , value_input = input_list
762
- mask = tf .cast (mask [0 ], tf .float32 )
763
- mask = tf .expand_dims (mask , axis = 2 )
764
- else :
765
- key_input , key_length_input , value_input = input_list
766
- mask = tf .sequence_mask (key_length_input ,
767
- self .seq_len_max , dtype = tf .float32 )
768
- mask = tf .transpose (mask , (0 , 2 , 1 ))
769
-
770
- embedding_size = key_input .shape [- 1 ]
771
- mask = tf .tile (mask , [1 , 1 , embedding_size ])
772
- key_input *= mask
773
- if len (tf .shape (value_input )) == 2 :
774
- value_input = tf .expand_dims (value_input , axis = 2 )
775
- value_input = tf .tile (value_input , [1 , 1 , embedding_size ])
776
- return tf .multiply (key_input ,value_input )
777
-
778
- def compute_output_shape (self , input_shape ):
779
- return input_shape [0 ]
780
-
781
- def compute_mask (self , inputs , mask ):
782
- if self .supports_masking :
783
- return mask [0 ]
784
- else :
785
- return None
786
-
787
- def get_config (self , ):
788
- config = {'supports_masking' : self .supports_masking }
789
- base_config = super (SequenceMultiplyLayer , self ).get_config ()
790
- return dict (list (base_config .items ()) + list (config .items ()))
0 commit comments