1
+ import matplotlib .pyplot as plt
2
+ import tensorflow as tf
3
def plot_attention_weights(attention, key, query):
    """Visualise Transformer attention weights, one subplot per head.

    Draws a two-column grid of ``matshow`` heatmaps (one per attention
    head) and displays the figure with ``plt.show()``.

    Parameters
    ----------
    attention : attention weights
        Shape ``(1, number of heads, length of key, length of query)``.
    key : key for attention computation
        A list of values shown as the x-tick labels.
    query : query for attention computation
        A list of values shown as the y-tick labels.
    """
    fig = plt.figure(figsize=(16, 8))

    # Drop the leading batch dimension: (1, H, K, Q) -> (H, K, Q).
    attention = tf.squeeze(attention, axis=0)

    num_heads = attention.shape[0]
    # Two columns of subplots; ceil-divide so an odd head count still
    # gets enough rows (floor division would make add_subplot raise).
    num_rows = (num_heads + 1) // 2

    fontdict = {'fontsize': 12}
    for head in range(num_heads):
        ax = fig.add_subplot(num_rows, 2, head + 1)
        ax.matshow(attention[head], cmap='viridis')

        ax.set_xticks(range(len(key)))
        ax.set_yticks(range(len(query)))
        # ax.set_ylim(len(query)-1.5, -0.5)

        ax.set_xticklabels(
            [str(i) for i in key],
            fontdict=fontdict, rotation=90)
        ax.set_yticklabels([str(i) for i in query], fontdict=fontdict)

        ax.set_xlabel('Head {}'.format(head + 1), fontdict=fontdict)

    plt.tight_layout()
    plt.show()
0 commit comments