
Commit 3c4cae1

Lingjun Liu authored and committed
add decoder part attention visualisation
1 parent a47aee1 commit 3c4cae1

File tree

4 files changed: +79 −13 lines changed


tensorlayer/models/transformer/attention_layer.py

Lines changed: 1 addition & 0 deletions
@@ -124,6 +124,7 @@ def forward(self, x, y, mask, cache=None):

         Returns:
             Attention layer output with shape [batch_size, length_x, hidden_size]
+            Attention weights with shape [batch_size, number_of_heads, length_x, length_y]
         """
         # Linearly project the query (q), key (k) and value (v) using different
         # learned projections. This is in preparation of splitting them into
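For reference, a minimal sketch of scaled dot-product attention that returns both the output and the per-head weights, illustrating the [batch_size, number_of_heads, length_x, length_y] shape documented above. This is not the TensorLayer attention layer itself, just an assumed illustration:

    import tensorflow as tf

    def scaled_dot_product_attention(q, k, v):
        # q: [batch, heads, length_x, depth]; k, v: [batch, heads, length_y, depth]
        logits = tf.matmul(q, k, transpose_b=True)                      # [batch, heads, length_x, length_y]
        logits /= tf.math.sqrt(tf.cast(tf.shape(k)[-1], tf.float32))    # scale by sqrt(depth)
        weights = tf.nn.softmax(logits, axis=-1)                        # attention weights
        output = tf.matmul(weights, v)                                  # [batch, heads, length_x, depth]
        return output, weights

    q = tf.random.normal([2, 4, 5, 8])       # batch=2, heads=4, length_x=5, depth=8
    k = v = tf.random.normal([2, 4, 7, 8])   # length_y=7
    out, w = scaled_dot_product_attention(q, k, v)
    print(w.shape)  # (2, 4, 5, 7) == [batch_size, number_of_heads, length_x, length_y]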

tensorlayer/models/transformer/transformer.py

Lines changed: 65 additions & 8 deletions
@@ -77,12 +77,55 @@ def forward(self, inputs, targets=None):
       training: boolean, whether in training mode or not.

     Returns:
-      If targets is defined, then return logits for each word in the target
-      sequence. float tensor with shape [batch_size, target_length, vocab_size]
-      If target is none, then generate output sequence one token at a time.
-      returns a dictionary {
-        outputs: [batch_size, decoded length]
-        scores: [batch_size, float]}
+      If targets is defined:
+        Logits for each word in the target sequence:
+          float tensor with shape [batch_size, target_length, vocab_size]
+        Self-attention weights for encoder part:
+          a dictionary of float tensors {
+            "layer_0": [batch_size, number_of_heads, source_length, source_length],
+            "layer_1": [batch_size, number_of_heads, source_length, source_length],
+            ...
+          }
+        Weights for decoder part:
+          a dictionary of dictionaries of float tensors {
+            "self": {
+              "layer_0": [batch_size, number_of_heads, target_length, target_length],
+              "layer_1": [batch_size, number_of_heads, target_length, target_length],
+              ...
+            }
+            "enc_dec": {
+              "layer_0": [batch_size, number_of_heads, source_length, target_length],
+              "layer_1": [batch_size, number_of_heads, source_length, target_length],
+              ...
+            }
+          }
+
+      If targets is None:
+        Output sequence generated one token at a time by auto-regressive beam-search decoding:
+          a dictionary {
+            outputs: [batch_size, decoded length]
+            scores: [batch_size, float]
+          }
+        Weights for decoder part:
+          a dictionary of dictionaries of float tensors {
+            "self": {
+              "layer_0": [batch_size, number_of_heads, target_length, target_length],
+              "layer_1": [batch_size, number_of_heads, target_length, target_length],
+              ...
+            }
+            "enc_dec": {
+              "layer_0": [batch_size, number_of_heads, source_length, target_length],
+              "layer_1": [batch_size, number_of_heads, source_length, target_length],
+              ...
+            }
+          }
+        Self-attention weights for encoder part:
+          a dictionary of float tensors {
+            "layer_0": [batch_size, number_of_heads, source_length, source_length],
+            "layer_1": [batch_size, number_of_heads, source_length, source_length],
+            ...
+          }
+
     """
     # # Variance scaling is used here because it seems to work in many problems.
     # # Other reasonable initializers may also work just as well.
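For orientation, this is how the two call modes are unpacked in tests/models/test_transformer.py further down in this commit; here `model_`, `X`, and `Y` stand for the Transformer instance and a single source/target pair as set up in that test:

    # training / teacher forcing: targets are given, attention weights come back with the logits
    logits, weights_encoder, weights_decoder = model_(inputs=X, targets=Y)

    # inference: no targets, beam-search decoding; the prediction dict and decoder weights come back together
    model_.eval()
    [prediction, weights_decoder], weights_encoder = model_(inputs=X)
    print(prediction["outputs"])                          # [batch_size, decoded length]
    print(weights_decoder["enc_dec"]["layer_0"].shape)    # encoder-decoder attention of the first layer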
@@ -118,6 +118,7 @@ def encode(self, inputs, attention_bias):

     Returns:
       float tensor with shape [batch_size, input_length, hidden_size]
+
     """

     # Prepare inputs to the layer stack by adding positional encodings and
@@ -223,7 +267,12 @@ def symbols_to_logits_fn(ids, i, cache):
     return symbols_to_logits_fn, weights

   def predict(self, encoder_outputs, encoder_decoder_attention_bias):
-    """Return predicted sequence."""
+    """
+
+    Return predicted sequence and decoder attention weights.
+
+
+    """
     batch_size = tf.shape(encoder_outputs)[0]
     input_length = tf.shape(encoder_outputs)[1]
     max_decode_length = input_length + self.params.extra_decode_length
@@ -263,7 +312,15 @@ def predict(self, encoder_outputs, encoder_decoder_attention_bias):
     top_decoded_ids = decoded_ids[:, 0, 1:]
     top_scores = scores[:, 0]

-    return {"outputs": top_decoded_ids, "scores": top_scores}, weights
+    # post-process the per-step attention weights into full-length maps
+    for i, weight in enumerate(weights):
+        if i == 0:
+            w = weight
+        else:
+            for k in range(len(w['self'])):
+                w['self']['layer_%d' % k] = tf.concat([w['self']['layer_%d' % k], weight['self']['layer_%d' % k]], 3)
+                w['enc_dec']['layer_%d' % k] = tf.concat([w['enc_dec']['layer_%d' % k], weight['enc_dec']['layer_%d' % k]], 2)
+    return {"outputs": top_decoded_ids, "scores": top_scores}, w


 class LayerNormalization(tl.layers.Layer):
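To make the stitching above concrete, here is a minimal standalone sketch with dummy tensors. The per-step shapes are an assumption for illustration (each decoding step emitting weights for the single newly generated query position), not something stated in this diff:

    import tensorflow as tf

    batch, heads, src_len, num_layers, steps = 2, 4, 7, 2, 5

    # dummy per-step weights: at step t the decoder attends from 1 new query position
    # to t+1 already generated tokens ("self") and to all src_len encoder positions ("enc_dec")
    per_step = [
        {
            "self": {"layer_%d" % l: tf.random.uniform([batch, heads, 1, t + 1]) for l in range(num_layers)},
            "enc_dec": {"layer_%d" % l: tf.random.uniform([batch, heads, 1, src_len]) for l in range(num_layers)},
        }
        for t in range(steps)
    ]

    # same stitching as in predict(): "self" grows along the key axis (3),
    # "enc_dec" grows along the query axis (2)
    for i, weight in enumerate(per_step):
        if i == 0:
            w = weight
        else:
            for k in range(len(w['self'])):
                w['self']['layer_%d' % k] = tf.concat([w['self']['layer_%d' % k], weight['self']['layer_%d' % k]], 3)
                w['enc_dec']['layer_%d' % k] = tf.concat([w['enc_dec']['layer_%d' % k], weight['enc_dec']['layer_%d' % k]], 2)

    print(w['self']['layer_0'].shape)     # (2, 4, 1, 1 + 2 + 3 + 4 + 5) = (2, 4, 1, 15)
    print(w['enc_dec']['layer_0'].shape)  # (2, 4, 5, 7)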
Lines changed: 0 additions & 1 deletion
@@ -1,4 +1,3 @@
 from .model_utils import *
 from .metrics import *
-from .subtokenizer import *
 from .attention_visualisation import *

tests/models/test_transformer.py

Lines changed: 13 additions & 4 deletions
@@ -41,7 +41,7 @@ class TINY_PARAMS(object):

     # Default prediction params
     extra_decode_length=5
-    beam_size=2
+    beam_size=1
     alpha=0.6 # used to calculate length normalization in beam search


@@ -64,7 +64,7 @@ def setUpClass(cls):

         assert cls.src_len == cls.tgt_len

-        cls.num_epochs = 20
+        cls.num_epochs = 100
         cls.n_step = cls.src_len // cls.batch_size

     @classmethod
@@ -108,25 +108,34 @@ def test_basic_simpleSeq2Seq(self):
                 model_.eval()
                 [prediction, weights_decoder], weights_encoder = model_(inputs = test_sample)

+
                 print("Prediction: >>>>> ", prediction["outputs"], "\n Target: >>>>> ", trainY[0:2, :], "\n\n")

             print('Epoch [{}/{}]: loss {:.4f}'.format(epoch + 1, self.num_epochs, total_loss / n_iter))


-        # visualise the self-attention weights at encoder
+        # visualise the self-attention weights at encoder during training
         trainX, trainY = shuffle(self.trainX, self.trainY)
         X = [trainX[0]]
         Y = [trainY[0]]
         logits, weights_encoder, weights_decoder = model_(inputs = X, targets = Y)
         attention_visualisation.plot_attention_weights(weights_encoder["layer_0"], X[0].numpy(), X[0].numpy())

-        # visualise the self-attention weights at encoder
+        # visualise the encoder-decoder-attention weights at decoder during training
         trainX, trainY = shuffle(self.trainX, self.trainY)
         X = [trainX[0]]
         Y = [trainY[0]]
         logits, weights_encoder, weights_decoder = model_(inputs = X, targets = Y)
         attention_visualisation.plot_attention_weights(weights_decoder["enc_dec"]["layer_0"], X[0].numpy(), Y[0])

+        # visualise the encoder-decoder-attention weights at decoder during inference
+        trainX, trainY = shuffle(self.trainX, self.trainY)
+        X = [trainX[0]]
+        # Y = [trainY[0]]
+        model_.eval()
+        [prediction, weights_decoder], weights_encoder = model_(inputs = X)
+        # print(X[0].numpy(), prediction["outputs"][0].numpy())
+        attention_visualisation.plot_attention_weights(weights_decoder["enc_dec"]["layer_0"], X[0].numpy(), prediction["outputs"][0].numpy())

 if __name__ == '__main__':
     unittest.main()
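The plots above come from tensorlayer.models.transformer.utils.attention_visualisation. As a rough mental model only (this is not that module's implementation; the helper name and the [num_heads, query_length, key_length] input shape are assumptions for illustration), an attention heat-map per head can be drawn like this:

    import matplotlib.pyplot as plt
    import numpy as np

    def plot_attention_heatmap(weights, key_tokens, query_tokens):
        """Draw one heat-map per attention head.

        weights: array-like of shape [num_heads, query_length, key_length]
        key_tokens / query_tokens: labels for the columns / rows.
        """
        weights = np.asarray(weights)
        num_heads = weights.shape[0]
        fig, axes = plt.subplots(1, num_heads, figsize=(4 * num_heads, 4))
        for h, ax in enumerate(np.atleast_1d(axes)):
            ax.imshow(weights[h], cmap="viridis")   # rows = query positions, columns = key positions
            ax.set_xticks(range(len(key_tokens)))
            ax.set_xticklabels(key_tokens)
            ax.set_yticks(range(len(query_tokens)))
            ax.set_yticklabels(query_tokens)
            ax.set_title("head %d" % h)
        plt.tight_layout()
        plt.show()

    # e.g. random weights of shape [num_heads=2, query_length=3, key_length=4]
    plot_attention_heatmap(np.random.rand(2, 3, 4), ["a", "b", "c", "d"], ["x", "y", "z"])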
