This repository was archived by the owner on Jul 10, 2025. It is now read-only.

Commit 33b20f1

Use the mask_tensor inside the example code.

1 parent: 40b799b

1 file changed: +1 addition, −1 deletion

rfcs/20200616-keras-multihead-attention.md

```diff
@@ -134,7 +134,7 @@ class MultiHeadAttention(tf.keras.layers.Layer):
     >>> target = tf.keras.Input(shape=[8, 16])
     >>> source = tf.keras.Input(shape=[4, 16])
     >>> mask_tensor = tf.keras.Input(shape=[8, 4])
-    >>> output_tensor, weights = layer([target, source])
+    >>> output_tensor, weights = layer([target, source], attention_mask=mask_tensor)
     >>> print(output_tensor.shape), print(weights.shape)
     (None, 8, 16) (None, 2, 8, 4)
```
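The one-line change above wires the previously unused `mask_tensor` input into the layer call via `attention_mask`. As an illustrative sketch only (not the Keras implementation), the effect of such a mask inside scaled dot-product attention can be shown with NumPy; the `masked_attention` helper and the shapes below are hypothetical, chosen to mirror the target length 8, source length 4, and feature size 16 from the example:

```python
import numpy as np

def masked_attention(query, key, value, mask=None):
    # Scaled dot-product attention with an optional boolean mask.
    # mask[i, j] == False means query position i may not attend to key position j.
    scores = query @ key.T / np.sqrt(query.shape[-1])
    if mask is not None:
        scores = np.where(mask, scores, -1e9)  # drive masked scores to ~zero weight
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ value, weights

rng = np.random.default_rng(0)
q = rng.normal(size=(8, 16))   # target sequence
k = rng.normal(size=(4, 16))   # source sequence
v = rng.normal(size=(4, 16))
mask = np.ones((8, 4), dtype=bool)
mask[:, -1] = False            # forbid attending to the last source position
out, w = masked_attention(q, k, v, mask)
print(out.shape, w.shape)      # (8, 16) (8, 4)
```

Passing the mask is what makes the masked attention weights collapse to zero; without it (as in the pre-commit example), `mask_tensor` was declared but had no effect on the output.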
