@@ -77,12 +77,55 @@ def forward(self, inputs, targets=None):
77
77
training: boolean, whether in training mode or not.
78
78
79
79
Returns:
80
- If targets is defined, then return logits for each word in the target
81
- sequence. float tensor with shape [batch_size, target_length, vocab_size]
82
- If target is none, then generate output sequence one token at a time.
83
- returns a dictionary {
84
- outputs: [batch_size, decoded length]
85
- scores: [batch_size, float]}
80
+ If targets is defined:
81
+ Logits for each word in the target sequence:
82
+ float tensor with shape [batch_size, target_length, vocab_size]
83
+ Self-attention weights for encoder part:
84
+ a dictionary of float tensors {
85
+ "layer_0": [batch_size, number_of_heads, source_length, source_length],
86
+ "layer_1": [batch_size, number_of_heads, source_length, source_length],
87
+ ...
88
+ }
89
+ Weights for decoder part:
90
+ a dictionary of dictionary of float tensors {
91
+ "self": {
92
+ "layer_0": [batch_size, number_of_heads, target_length, target_length],
93
+ "layer_1": [batch_size, number_of_heads, target_length, target_length],
94
+ ...
95
+ }
96
+ "enc_dec": {
97
+ "layer_0": [batch_size, number_of_heads, source_length, target_length],
98
+ "layer_1": [batch_size, number_of_heads, source_length, target_length],
99
+ ...
100
+ }
101
+ }
102
+
103
+ If target is none:
104
+ Auto-regressive beam-search decoding to generate output each one time step:
105
+ a dictionary {
106
+ outputs: [batch_size, decoded length]
107
+ scores: [batch_size, float]}
108
+ }
109
+ Weights for decoder part:
110
+ a dictionary of dictionary of float tensors {
111
+ "self": {
112
+ "layer_0": [batch_size, number_of_heads, target_length, target_length],
113
+ "layer_1": [batch_size, number_of_heads, target_length, target_length],
114
+ ...
115
+ }
116
+ "enc_dec": {
117
+ "layer_0": [batch_size, number_of_heads, source_length, target_length],
118
+ "layer_1": [batch_size, number_of_heads, source_length, target_length],
119
+ ...
120
+ }
121
+ }
122
+ Self-attention weights for encoder part:
123
+ a dictionary of float tensors {
124
+ "layer_0": [batch_size, number_of_heads, source_length, source_length],
125
+ "layer_1": [batch_size, number_of_heads, source_length, source_length],
126
+ ...
127
+ }
128
+
86
129
"""
87
130
# # Variance scaling is used here because it seems to work in many problems.
88
131
# # Other reasonable initializers may also work just as well.
@@ -118,6 +161,7 @@ def encode(self, inputs, attention_bias):
118
161
119
162
Returns:
120
163
float tensor with shape [batch_size, input_length, hidden_size]
164
+
121
165
"""
122
166
123
167
# Prepare inputs to the layer stack by adding positional encodings and
@@ -223,7 +267,12 @@ def symbols_to_logits_fn(ids, i, cache):
223
267
return symbols_to_logits_fn , weights
224
268
225
269
def predict (self , encoder_outputs , encoder_decoder_attention_bias ):
226
- """Return predicted sequence."""
270
+ """
271
+
272
+ Return predicted sequence, and decoder attention weights.
273
+
274
+
275
+ """
227
276
batch_size = tf .shape (encoder_outputs )[0 ]
228
277
input_length = tf .shape (encoder_outputs )[1 ]
229
278
max_decode_length = input_length + self .params .extra_decode_length
@@ -263,7 +312,15 @@ def predict(self, encoder_outputs, encoder_decoder_attention_bias):
263
312
top_decoded_ids = decoded_ids [:, 0 , 1 :]
264
313
top_scores = scores [:, 0 ]
265
314
266
- return {"outputs" : top_decoded_ids , "scores" : top_scores }, weights
315
+ # post-process the weight attention
316
+ for i , weight in enumerate (weights ):
317
+ if (i == 0 ):
318
+ w = weight
319
+ else :
320
+ for k in range (len (w ['self' ])):
321
+ w ['self' ]['layer_%d' % k ] = tf .concat ([w ['self' ]['layer_%d' % k ], weight ['self' ]['layer_%d' % k ]], 3 )
322
+ w ['enc_dec' ]['layer_%d' % k ] = tf .concat ([w ['enc_dec' ]['layer_%d' % k ], weight ['enc_dec' ]['layer_%d' % k ]], 2 )
323
+ return {"outputs" : top_decoded_ids , "scores" : top_scores }, w
267
324
268
325
269
326
class LayerNormalization (tl .layers .Layer ):
0 commit comments