import matplotlib.pyplot as plt
import tensorflow as tf


def plot_attention_weights(attention, key, query):
    '''Attention visualisation for Transformer

    Parameters
    ...
    '''

    fig = plt.figure(figsize=(16, 8))

    # Collapse the batch dimension: (1, num_heads, q_len, k_len) -> (num_heads, q_len, k_len)
    attention = tf.squeeze(attention, axis=0)

    # One subplot per attention head, arranged on a two-column grid
    for head in range(attention.shape[0]):
        ax = fig.add_subplot(attention.shape[0] // 2, 2, head + 1)

        # Heat map of this head's query-to-key weights
        ax.matshow(attention[head], cmap='viridis')
        fontdict = {'fontsize': 12}
        ax.set_xticks(range(len(key)))
        ax.set_yticks(range(len(query)))

        # ax.set_ylim(len(query)-1.5, -0.5)

        # Label the axes with the raw key/query tokens
        ax.set_xticklabels([str(i) for i in key], fontdict=fontdict, rotation=90)

        ax.set_yticklabels([str(i) for i in query], fontdict=fontdict)

        ax.set_xlabel('Head {}'.format(head + 1), fontdict=fontdict)

    plt.tight_layout()
    plt.show()
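

# A minimal usage sketch, not part of the original file: the head count, token
# ids, and the random `weights` tensor are hypothetical stand-ins for real
# attention weights of shape (batch=1, num_heads, q_len, k_len).
if __name__ == '__main__':
    query = [1, 2, 3, 4]     # hypothetical query token ids
    key = [5, 6, 7, 8, 9]    # hypothetical key token ids
    weights = tf.random.uniform((1, 8, len(query), len(key)))
    # Normalise each row so it resembles a softmax attention distribution
    weights /= tf.reduce_sum(weights, axis=-1, keepdims=True)
    plot_attention_weights(weights, key, query)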