Commit 048d9a3

Lingjun Liu authored and committed
add attention visualisation
1 parent 3ef8d8b commit 048d9a3

File tree

1 file changed: +19 -6 lines changed


tests/models/test_transformer.py

Lines changed: 19 additions & 6 deletions
@@ -4,8 +4,6 @@
 import os
 import unittest
 
-os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
-
 import numpy as np
 import tensorflow as tf
 import tensorlayer as tl
@@ -14,8 +12,8 @@
 from tensorlayer.models.transformer import Transformer
 from tests.utils import CustomTestCase
 from tensorlayer.models.transformer.utils import metrics
-from tensorlayer.cost import cross_entropy_seq
 from tensorlayer.optimizers import lazyAdam as optimizer
+from tensorlayer.models.transformer.utils import attention_visualisation
 import time
 
 
@@ -51,7 +49,7 @@ class Model_SEQ2SEQ_Test(CustomTestCase):
 
     @classmethod
     def setUpClass(cls):
-        cls.batch_size = 16
+        cls.batch_size = 50
 
         cls.embedding_size = 32
         cls.dec_seq_length = 5
@@ -66,7 +64,7 @@ def setUpClass(cls):
 
         assert cls.src_len == cls.tgt_len
 
-        cls.num_epochs = 1000
+        cls.num_epochs = 20
         cls.n_step = cls.src_len // cls.batch_size
 
     @classmethod
@@ -99,8 +97,8 @@ def test_basic_simpleSeq2Seq(self):
 
                 grad = tape.gradient(loss, model_.all_weights)
                 optimizer.apply_gradients(zip(grad, model_.all_weights))
-
 
+
                 total_loss += loss
                 n_iter += 1
             print(time.time()-t)
@@ -115,5 +113,20 @@ def test_basic_simpleSeq2Seq(self):
             print('Epoch [{}/{}]: loss {:.4f}'.format(epoch + 1, self.num_epochs, total_loss / n_iter))
 
 
+        # visualise the self-attention weights at the encoder
+        trainX, trainY = shuffle(self.trainX, self.trainY)
+        X = [trainX[0]]
+        Y = [trainY[0]]
+        logits, weights_encoder, weights_decoder = model_(inputs = X, targets = Y)
+        attention_visualisation.plot_attention_weights(weights_encoder["layer_0"], X[0].numpy(), X[0].numpy())
+
+        # visualise the encoder-decoder attention weights at the decoder
+        trainX, trainY = shuffle(self.trainX, self.trainY)
+        X = [trainX[0]]
+        Y = [trainY[0]]
+        logits, weights_encoder, weights_decoder = model_(inputs = X, targets = Y)
+        attention_visualisation.plot_attention_weights(weights_decoder["enc_dec"]["layer_0"], X[0].numpy(), Y[0])
+
+
 if __name__ == '__main__':
     unittest.main()
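For context on what the new test calls are expected to do, here is a minimal sketch of an attention-weight plotting helper in the spirit of attention_visualisation.plot_attention_weights. It is a hypothetical illustration only, assuming matplotlib and a weight tensor shaped [batch, num_heads, query_len, key_len] with the argument order (weights, key_tokens, query_tokens) used in the diff above; the actual tensorlayer utility may differ.

# Hypothetical sketch only -- the real tensorlayer utility may differ.
import matplotlib.pyplot as plt
import tensorflow as tf


def plot_attention_weights(attention, key_tokens, query_tokens):
    """Draw one heatmap per attention head for a single example."""
    # assumed shape: [1, num_heads, query_len, key_len] -> [num_heads, query_len, key_len]
    attention = tf.squeeze(attention, axis=0)
    num_heads = attention.shape[0]
    fig = plt.figure(figsize=(16, 8))
    for head in range(num_heads):
        ax = fig.add_subplot(2, (num_heads + 1) // 2, head + 1)
        ax.matshow(attention[head], cmap='viridis')  # rows = queries, columns = keys
        ax.set_xticks(range(len(key_tokens)))
        ax.set_yticks(range(len(query_tokens)))
        ax.set_xticklabels([str(t) for t in key_tokens], rotation=90, fontsize=7)
        ax.set_yticklabels([str(t) for t in query_tokens], fontsize=7)
        ax.set_xlabel('Head {}'.format(head + 1))
    plt.tight_layout()
    plt.show()

With a helper along these lines, each call added in the test would produce one figure per attention head for the chosen example: token self-attention for the encoder call, and encoder-decoder attention for the decoder call.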
