Skip to content

Commit a47aee1

Browse files
Lingjun LiuLingjun Liu
authored andcommitted
add attention visualisation
1 parent 048d9a3 commit a47aee1

File tree

2 files changed

+44
-1
lines changed

2 files changed

+44
-1
lines changed
Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,4 @@
11
from .model_utils import *
2-
from .metrics import *
2+
from .metrics import *
3+
from .subtokenizer import *
4+
from .attention_visualisation import *
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import matplotlib.pyplot as plt
2+
import tensorflow as tf
3+
def plot_attention_weights(attention, key, query):
4+
5+
'''Attention visualisation for Transformer
6+
7+
Parameters
8+
----------
9+
attention : attention weights
10+
shape of (1, number of head, length of key, length of query).
11+
12+
key : key for attention computation
13+
a list of values which would be shown as xtick labels
14+
15+
value : value for attention computation
16+
a list of values which would be shown as ytick labels
17+
18+
'''
19+
20+
21+
fig = plt.figure(figsize=(16, 8))
22+
23+
attention = tf.squeeze(attention, axis=0)
24+
25+
for head in range(attention.shape[0]):
26+
ax = fig.add_subplot(attention.shape[0]//2, 2, head+1)
27+
ax.matshow(attention[head], cmap='viridis')
28+
fontdict = {'fontsize': 12}
29+
ax.set_xticks(range(len(key)))
30+
ax.set_yticks(range(len(query)))
31+
32+
# ax.set_ylim(len(query)-1.5, -0.5)
33+
ax.set_xticklabels(
34+
[str(i) for i in key],
35+
fontdict=fontdict, rotation=90)
36+
37+
ax.set_yticklabels([str(i) for i in query], fontdict=fontdict)
38+
39+
ax.set_xlabel('Head {}'.format(head+1), fontdict = fontdict)
40+
plt.tight_layout()
41+
plt.show()

0 commit comments

Comments
 (0)