@@ -4,8 +4,6 @@
 import os
 import unittest
 
-os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
-
 import numpy as np
 import tensorflow as tf
 import tensorlayer as tl
@@ -14,8 +12,8 @@
 from tensorlayer.models.transformer import Transformer
 from tests.utils import CustomTestCase
 from tensorlayer.models.transformer.utils import metrics
-from tensorlayer.cost import cross_entropy_seq
 from tensorlayer.optimizers import lazyAdam as optimizer
+from tensorlayer.models.transformer.utils import attention_visualisation
 import time
 
 
@@ -51,7 +49,7 @@ class Model_SEQ2SEQ_Test(CustomTestCase):
 
     @classmethod
     def setUpClass(cls):
-        cls.batch_size = 16
+        cls.batch_size = 50
 
         cls.embedding_size = 32
         cls.dec_seq_length = 5
@@ -66,7 +64,7 @@ def setUpClass(cls):
 
         assert cls.src_len == cls.tgt_len
 
-        cls.num_epochs = 1000
+        cls.num_epochs = 20
         cls.n_step = cls.src_len // cls.batch_size
 
     @classmethod
@@ -99,8 +97,8 @@ def test_basic_simpleSeq2Seq(self):
 
                 grad = tape.gradient(loss, model_.all_weights)
                 optimizer.apply_gradients(zip(grad, model_.all_weights))
-
 
+
                 total_loss += loss
                 n_iter += 1
             print(time.time() - t)
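For context, the two gradient lines in this hunk follow the standard TF2 tf.GradientTape update step. A minimal self-contained sketch of the same pattern, using a stand-in Keras model (trainable_variables playing the role of TensorLayer's model_.all_weights, and an MSE loss standing in for the test's actual sequence loss, which is computed outside this hunk):

import tensorflow as tf

# Stand-in model and data; the real test uses a TensorLayer Transformer.
model = tf.keras.Sequential([tf.keras.layers.Dense(4)])
optimizer = tf.keras.optimizers.Adam()
x = tf.random.normal([8, 3])
y = tf.random.normal([8, 4])

with tf.GradientTape() as tape:
    loss = tf.reduce_mean(tf.square(model(x) - y))  # placeholder loss

# Same two-step update as in the hunk above:
grad = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(grad, model.trainable_variables))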
@@ -115,5 +113,20 @@ def test_basic_simpleSeq2Seq(self):
             print('Epoch [{}/{}]: loss {:.4f}'.format(epoch + 1, self.num_epochs, total_loss / n_iter))
 
 
+        # visualise the self-attention weights at the encoder
+        trainX, trainY = shuffle(self.trainX, self.trainY)
+        X = [trainX[0]]
+        Y = [trainY[0]]
+        logits, weights_encoder, weights_decoder = model_(inputs=X, targets=Y)
+        attention_visualisation.plot_attention_weights(weights_encoder["layer_0"], X[0].numpy(), X[0].numpy())
+
+        # visualise the encoder-decoder attention weights at the decoder
+        trainX, trainY = shuffle(self.trainX, self.trainY)
+        X = [trainX[0]]
+        Y = [trainY[0]]
+        logits, weights_encoder, weights_decoder = model_(inputs=X, targets=Y)
+        attention_visualisation.plot_attention_weights(weights_decoder["enc_dec"]["layer_0"], X[0].numpy(), Y[0])
+
+
 if __name__ == '__main__':
     unittest.main()
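The plot_attention_weights helper added in this commit presumably renders an attention matrix as a heatmap, labelling the axes with the source and target tokens passed at the call sites above. A rough, self-contained sketch of that idea with matplotlib (the function name, signature, and the [heads, tgt_len, src_len] weight layout are inferred from the call sites, not taken from the TensorLayer source):

import numpy as np
import matplotlib.pyplot as plt

# Sketch of a plot_attention_weights(weights, x_tokens, y_tokens)-style helper:
# one heatmap per attention head, with tokens as tick labels.
def plot_attention_weights_sketch(weights, x_tokens, y_tokens):
    weights = np.asarray(weights)  # assumed shape: [heads, len(y_tokens), len(x_tokens)]
    n_heads = weights.shape[0]
    fig, axes = plt.subplots(1, n_heads, figsize=(4 * n_heads, 4))
    for head, ax in zip(range(n_heads), np.atleast_1d(axes)):
        ax.imshow(weights[head], cmap='viridis')
        ax.set_xticks(range(len(x_tokens)))
        ax.set_xticklabels(x_tokens, rotation=90)
        ax.set_yticks(range(len(y_tokens)))
        ax.set_yticklabels(y_tokens)
        ax.set_title('head {}'.format(head))
    plt.show()

# Example: 2 heads attending over a 4-token sequence (random weights).
plot_attention_weights_sketch(np.random.rand(2, 4, 4), list('abcd'), list('abcd'))

In the test above the "tokens" are integer ids (X[0].numpy() and Y[0]), which work equally well as tick labels.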