Skip to content

Commit f5438a7

Browse files
Lingjun LiuLingjun Liu
authored andcommitted
documentation
1 parent 4d2e19e commit f5438a7

File tree

1 file changed

+7
-9
lines changed

1 file changed

+7
-9
lines changed
Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import matplotlib.pyplot as plt
22
import tensorflow as tf
3-
def plot_attention_weights(attention, key, query):
43

4+
5+
def plot_attention_weights(attention, key, query):
56
'''Attention visualisation for Transformer
67
78
Parameters
@@ -17,25 +18,22 @@ def plot_attention_weights(attention, key, query):
1718
1819
'''
1920

20-
2121
fig = plt.figure(figsize=(16, 8))
2222

2323
attention = tf.squeeze(attention, axis=0)
24-
24+
2525
for head in range(attention.shape[0]):
26-
ax = fig.add_subplot(attention.shape[0]//2, 2, head+1)
26+
ax = fig.add_subplot(attention.shape[0] // 2, 2, head + 1)
2727
ax.matshow(attention[head], cmap='viridis')
2828
fontdict = {'fontsize': 12}
2929
ax.set_xticks(range(len(key)))
3030
ax.set_yticks(range(len(query)))
3131

3232
# ax.set_ylim(len(query)-1.5, -0.5)
33-
ax.set_xticklabels(
34-
[str(i) for i in key],
35-
fontdict=fontdict, rotation=90)
33+
ax.set_xticklabels([str(i) for i in key], fontdict=fontdict, rotation=90)
3634

3735
ax.set_yticklabels([str(i) for i in query], fontdict=fontdict)
3836

39-
ax.set_xlabel('Head {}'.format(head+1), fontdict = fontdict)
37+
ax.set_xlabel('Head {}'.format(head + 1), fontdict=fontdict)
4038
plt.tight_layout()
41-
plt.show()
39+
plt.show()

0 commit comments

Comments
 (0)